from typing import List, Optional

import torch

from apadata.strategies.score_strategy import ScoreStrategyEnum
from apadata.strategies.score_strategy_factory import ScoreStrategyFactory
from apadata.text_processors.evaluators.evaluator import Evaluator
from apadata.text_processors.text_similarity import (
    TextSimilarity,
    load_industry_names_embs,
)


class SynonymyEvaluator(Evaluator):
    """Score a keyword by its semantic similarity to a set of industry names.

    The keyword is compared against every industry-name embedding with
    ``TextSimilarity.text_similarity``. The resulting scores are reduced to a
    single number by the configured score strategy (highest score by default).

    By default the embeddings come from ``load_industry_names_embs()``. If
    both ``industry_names`` and ``context`` are supplied, they override the
    default embeddings. In that case embeddings are computed once, on the
    first ``evaluate`` call, from sentences of the form
    ``"<industry_name> <connector> <context>"``.
    """

    # Shared by all instances; constructed once, when the class is defined.
    ts = TextSimilarity()

    def __init__(
        self,
        strategy: ScoreStrategyEnum = ScoreStrategyEnum.HIGHEST,
        industry_name_embs: Optional[torch.Tensor] = None,
        industry_names: Optional[List[str]] = None,
        context: str = "",
        connector: str = "in",
    ):
        """
        Args:
            strategy: How the per-industry similarity scores are combined.
            industry_name_embs: Precomputed industry-name embeddings. When
                omitted, ``load_industry_names_embs()`` is used.
            industry_names: Industry names to embed together with ``context``.
                Must be supplied together with ``context`` (checked in
                ``evaluate``).
            context: Text appended to each industry name before embedding.
            connector: Word placed between an industry name and the context.
        """
        self.strategy: ScoreStrategyEnum = strategy
        # Compare against None explicitly: `not tensor` / `tensor or []`
        # raise for tensors with more than one element.
        if industry_name_embs is None:
            industry_name_embs = load_industry_names_embs()
        # Keep the original fallback to an empty collection.
        self.industry_name_embs = (  # type: ignore
            [] if industry_name_embs is None else industry_name_embs
        )
        self.industry_names = industry_names or []
        self.context = context
        self.connector = connector
        # Becomes True once the context-specific embeddings have replaced
        # the default ones, so they are only computed once.
        self._context_embs_ready = False

    def evaluate(self, keyword: str) -> float:
        """Return the aggregated semantic similarity of ``keyword``.

        Args:
            keyword: Text to compare against every industry-name embedding.

        Returns:
            The score produced by the configured strategy, as a float.

        Raises:
            ValueError: if exactly one of ``industry_names`` / ``context``
                is set.
        """
        if bool(self.industry_names) != bool(self.context):
            raise ValueError(
                "You must supply either both the industry_names and the "
                "context or none of them!"
            )
        # Context-specific embeddings override the default ones. They are
        # computed lazily, once, and reused on subsequent calls.
        if self.industry_names and not self._context_embs_ready:
            self.industry_name_embs = self._encode_industry_names_with_context()
            self._context_embs_ready = True

        sims: List[float] = [
            SynonymyEvaluator.ts.text_similarity(keyword, industry_name_emb)
            for industry_name_emb in self.industry_name_embs
        ]
        final_sim = ScoreStrategyFactory.create(self.strategy).calculate(sims)
        return float(final_sim)

    def _encode_industry_names_with_context(self) -> torch.Tensor:
        """Embed "<industry_name> <connector> <context>" for each industry name."""
        # Join with spaces and skip empty parts (e.g. an empty connector),
        # so the words stay separated.
        sentences = [
            " ".join(part for part in (name, self.connector, self.context) if part)
            for name in self.industry_names
        ]
        # NOTE(review): encode_text's return type is not visible here.
        # torch.stack over torch.as_tensor accepts tensors, arrays and lists
        # alike. torch.Tensor(list_of_tensors) fails for multi-element
        # tensors. The default dtype matches what torch.Tensor(...) produced.
        return torch.stack(
            [
                torch.as_tensor(
                    SynonymyEvaluator.ts.encode_text(sentence),
                    dtype=torch.get_default_dtype(),
                )
                for sentence in sentences
            ]
        )
