from typing import Any, List, Set

from itertools import chain

from nltk import download
from nltk.corpus import wordnet as wn
from nltk.corpus.reader import WordNetError

from apadata.text_processors.evaluators.evaluator import Evaluator

# Per-token scores awarded by GeneralityEvaluator.evaluate.
FULL_POINT = 1.0
HALF_POINT = 0.5 * FULL_POINT
NO_POINTS = 0.0
# Divisor applied to the raw hyponym / hypernym counts.
SCALE_FACTOR = 1_000


class GeneralityEvaluator(Evaluator):
    """Evaluates the commonality of a keyword using the WordNet taxonomy.

    Each whitespace-separated token is scored by comparing how many distinct
    lemma names are reachable through its hypernym closure versus its
    hyponym closure; the per-token scores are then averaged.
    """

    def evaluate(self, keyword: str, **kwargs: Any) -> float:
        """Return the mean per-token generality score of ``keyword``.

        Per token: FULL_POINT when it has no hyponyms or fewer hyponyms than
        hypernyms, HALF_POINT when the (scaled) counts are equal, and
        ``hyper_count / hypo_count`` otherwise.

        Args:
            keyword: Text to score; split on whitespace into tokens.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The average token score, or NO_POINTS when ``keyword`` contains
            no tokens (empty or whitespace-only string).
        """
        keyword_tokens = keyword.split()
        if not keyword_tokens:
            # Guard against ZeroDivisionError in the average below.
            return NO_POINTS
        sum_scores = NO_POINTS
        for token in keyword_tokens:
            hypo_count = self.count_hyponyms(token)
            hyper_count = self.count_hypernyms(token)
            if not hypo_count or hypo_count < hyper_count:
                sum_scores += FULL_POINT
            elif hypo_count == hyper_count:
                # NOTE(review): equal counts score 0.5, yet the ratio branch
                # gives ~1.0 for nearly-equal counts (e.g. 99/100) — confirm
                # this discontinuity is intended.
                sum_scores += HALF_POINT
            else:
                sum_scores += hyper_count / hypo_count
        return sum_scores / len(keyword_tokens)

    def get_synsets(self, word: str, num_synsets: int) -> List[Any]:
        """Return up to ``num_synsets`` WordNet synsets for ``word``.

        A falsy ``num_synsets`` (0 or None) is treated as 1. If the WordNet
        corpus is missing (LookupError), it is downloaded and the lookup is
        retried once. KeyError / WordNetError from the lookup yield ``[]``.
        """
        num_synsets = num_synsets or 1
        try:
            try:
                word_synsets = wn.synsets(word)
            except LookupError:
                # Corpus not installed yet: fetch it, then retry once.
                download("wordnet")
                word_synsets = wn.synsets(word)
        except (KeyError, WordNetError):
            return []
        return list(word_synsets[:num_synsets])

    def _closure_lemma_names(
        self, word: str, num_synsets: int, relation: Any
    ) -> Set[str]:
        """Collect lemma names of all synsets reached via ``Synset.closure(relation)``."""
        lemma_names: Set[str] = set()
        for word_synset in self.get_synsets(word, num_synsets):
            lemma_names.update(
                chain.from_iterable(
                    synset.lemma_names()
                    for synset in word_synset.closure(relation)
                )
            )
        return lemma_names

    def get_hyponyms(self, word: str, num_synsets: int = 1) -> Set[str]:
        """Return the distinct lemma names in the hyponym closure of ``word``."""
        return self._closure_lemma_names(
            word, num_synsets, lambda s: s.hyponyms()
        )

    def get_hypernyms(self, word: str, num_synsets: int = 1) -> Set[str]:
        """Return the distinct lemma names in the hypernym closure of ``word``."""
        return self._closure_lemma_names(
            word, num_synsets, lambda s: s.hypernyms()
        )

    def count_hyponyms(self, keyword: str) -> float:
        """Return the number of hyponym lemma names divided by SCALE_FACTOR."""
        return len(self.get_hyponyms(keyword)) / SCALE_FACTOR

    def count_hypernyms(self, keyword: str) -> float:
        """Return the number of hypernym lemma names divided by SCALE_FACTOR."""
        return len(self.get_hypernyms(keyword)) / SCALE_FACTOR
