from typing import Any, Dict, Optional, Union

import torch
from sentence_transformers import SentenceTransformer, util

from apadata.text_processors import TextProcessor


def load_model():
    """Load the ``all-MiniLM-L6-v2`` sentence-transformers model.

    Returns:
        The loaded ``SentenceTransformer``; ``.cuda()`` is called on it first
        when ``torch.cuda.is_available()`` reports a usable device.
    """
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    if not torch.cuda.is_available():
        # No CUDA device: hand the model back exactly as it was constructed.
        return encoder
    encoder.cuda()
    return encoder


class TextSimilarity(TextProcessor):
    """Compute the cosine similarity between text embeddings.

    Texts are embedded with the model returned by ``load_model``. The model
    and the embedding cache are class attributes, so they are shared by every
    instance: the model is loaded on first use and each distinct text is
    encoded only once.
    """

    # Embedding cache keyed by the exact input text; shared by all instances.
    text_to_emb: Dict[str, torch.Tensor] = {}
    # Embedding model; stays None until the first ``encode_text`` call.
    model: Optional[Any] = None

    def __init__(self):
        """Initialise the base ``TextProcessor`` with an empty string."""
        super().__init__("")

    def encode_text(self, text: str) -> torch.Tensor:
        """Return the embedding of ``text``, encoding it on first use.

        Args:
            text: The text to embed.

        Returns:
            A tensor holding the embedding. It is a copy of the cached entry,
            so in-place changes made by the caller cannot corrupt the cache.
        """
        # Compare with None explicitly so the "not loaded yet" check does not
        # depend on how the model object defines its own truthiness.
        if TextSimilarity.model is None:
            TextSimilarity.model = load_model()

        cached = TextSimilarity.text_to_emb.get(text)
        if cached is None:
            raw_embedding = TextSimilarity.model.encode(
                text, convert_to_tensor=False
            ).astype("float")
            # Convert once at cache time: the cache then really holds tensors
            # (as annotated) and lookups do not repeat the array conversion.
            cached = torch.Tensor(raw_embedding)
            TextSimilarity.text_to_emb[text] = cached
        return cached.clone()

    def _get_embedding(self, text: Union[str, torch.Tensor]) -> torch.Tensor:
        """Return an embedding for a text or pass a ready-made tensor through.

        Args:
            text: Either a string to encode or an already computed embedding.

        Returns:
            The embedding of the string, or the given tensor unchanged.

        Raises:
            ValueError: If ``text`` is neither a ``str`` nor a ``torch.Tensor``.
        """
        if isinstance(text, str):
            return self.encode_text(text)
        if isinstance(text, torch.Tensor):
            return text
        raise ValueError("Wrong type")

    def text_similarity(
        self, text_a: Union[str, torch.Tensor], text_b: Union[str, torch.Tensor]
    ) -> float:
        """Return the cosine similarity between two texts or embeddings.

        Args:
            text_a: First text, or its precomputed embedding.
            text_b: Second text, or its precomputed embedding.

        Returns:
            The first entry of the cosine-similarity matrix computed by
            ``util.pytorch_cos_sim``, as a Python float.

        Raises:
            ValueError: If an argument is neither a ``str`` nor a
                ``torch.Tensor``.
        """
        embedding_a = self._get_embedding(text_a)
        embedding_b = self._get_embedding(text_b)

        # A caller-supplied tensor may use another dtype or live on another
        # device than a freshly encoded text; align both operands so the
        # comparison does not fail. Both steps are no-ops when they already
        # agree.
        common_dtype = torch.promote_types(embedding_a.dtype, embedding_b.dtype)
        embedding_a = embedding_a.to(dtype=common_dtype)
        embedding_b = embedding_b.to(device=embedding_a.device, dtype=common_dtype)

        return float(util.pytorch_cos_sim(embedding_a, embedding_b)[0][0].item())
