import time
from typing import Dict, List, Union

import cachetools
import cachetools.func
import numpy as np
import torch
from sentence_transformers import util

from apadata.models import ExternalIndustry
from apadata.text_processors import TextProcessor
from apadata.text_processors.text_embedder_processor import TextEmbedderProcessor


# NOTE: `cachetools.func` is a submodule; it must be imported explicitly
# (`import cachetools.func`) because `import cachetools` alone does not load it.
@cachetools.func.ttl_cache(
    maxsize=10_000, ttl=86_400, timer=time.monotonic, typed=False
)
def load_industry_names_embs() -> list:
    """Return one embedding per ``ExternalIndustry`` name, cached for a day.

    The names are read from ``ExternalIndustry.load_all()`` and each one is
    run through a single ``TextEmbedderProcessor`` instance, in the order the
    rows were returned.

    The result is memoised by ``ttl_cache`` for 86 400 seconds (24 hours)
    measured with ``time.monotonic``. The function takes no arguments, so
    only one cache entry is ever populated despite the larger ``maxsize``.

    Returns:
        A list holding the value of ``TextEmbedderProcessor.process()`` for
        each industry name. While the cache entry is alive every caller gets
        the *same* list object, so it should be treated as read-only.
    """
    # Materialise the names before building the embedder, as the original
    # code did, so loading finishes before any embedding work starts.
    industry_names = [row.name for row in ExternalIndustry.load_all()]

    # One embedder is reused for all names; only its `text` changes per item.
    text_embedder = TextEmbedderProcessor("")
    embeddings = []
    for name in industry_names:
        text_embedder.text = name
        embeddings.append(text_embedder.process())
    return embeddings


class TextSimilarity(TextProcessor):
    """Text processor that scores how similar two inputs are.

    The score is the value produced by
    ``sentence_transformers.util.pytorch_cos_sim`` for the embeddings of the
    two inputs. Strings are embedded with a ``TextEmbedderProcessor``;
    already-numeric inputs (list, NumPy array, tensor) are used as embeddings
    directly.
    """

    # Raw embedder output keyed by the text it was computed for. This is a
    # class attribute, so it is shared by every instance; nothing in this
    # class removes entries from it.
    text_to_emb: Dict[str, torch.Tensor] = {}

    def __init__(self, text: str = "", similar_text: str = ""):
        """Store both texts and create the embedder used for string inputs.

        Args:
            text: First text; forwarded to ``TextProcessor.__init__``.
            similar_text: Second text, compared against ``text`` by
                :meth:`process`.
        """
        super().__init__(text)
        self.similar_text = similar_text
        self.text_embedder = TextEmbedderProcessor("")

    def process(self) -> float:
        """Return the similarity of ``self.text`` and ``self.similar_text``."""
        # `self.text` is presumably set by TextProcessor.__init__ — confirm.
        first, second = self.text, self.similar_text
        return self.text_similarity(first, second)

    def encode_text(self, text: str) -> torch.Tensor:
        """Embed ``text``, reusing the class-wide cache when possible.

        Args:
            text: The string to embed.

        Returns:
            A new ``torch.Tensor`` built from the cached (or freshly
            computed) embedder output.
        """
        cache = TextSimilarity.text_to_emb
        if text in cache:
            return torch.Tensor(cache[text])

        # Cache miss: run the embedder and remember its raw output.
        self.text_embedder.text = text
        raw_embedding = self.text_embedder.process()
        cache[text] = raw_embedding
        return torch.Tensor(raw_embedding)

    def _get_embedding(
        self, text: Union[str, List[str], torch.Tensor, np.ndarray]  # type: ignore
    ) -> torch.Tensor:
        """Turn a supported input into an embedding tensor.

        Strings are embedded via :meth:`encode_text`. Lists and NumPy arrays
        are handed straight to ``torch.Tensor`` (i.e. treated as numeric
        data, not as text to encode). Tensors are returned unchanged.

        Raises:
            ValueError: If ``text`` is none of the supported types.
        """
        if isinstance(text, str):
            return self.encode_text(text)
        if isinstance(text, (list, np.ndarray)):
            return torch.Tensor(text)
        if isinstance(text, torch.Tensor):
            return text
        raise ValueError("Wrong type")

    def text_similarity(
        self, text_a: Union[str, torch.Tensor], text_b: Union[str, torch.Tensor]
    ) -> float:
        """Return the similarity score between two texts or embeddings.

        Args:
            text_a: First input; any type accepted by ``_get_embedding``.
            text_b: Second input; any type accepted by ``_get_embedding``.

        Returns:
            The ``[0][0]`` entry of the ``pytorch_cos_sim`` result as a
            Python float.
        """
        score_matrix = util.pytorch_cos_sim(
            self._get_embedding(text_a), self._get_embedding(text_b)
        )
        top_left = score_matrix[0][0]
        return float(top_left.item())
