from typing import List

import string
from itertools import chain
from statistics import mean

from cachetools import TTLCache, cached
from spacy.lang.en import stop_words

from apadata.models import RedList
from apadata.text_processors.evaluators.evaluator import Evaluator

# Characters a keyword may contain without being penalised: ASCII letters
# (both cases), ASCII digits, the space character and the marks , . & /
WHITELISTED_CHARACTERS = set(string.ascii_letters + string.digits + " ,.&/")


@cached(cache=TTLCache(maxsize=10_000, ttl=1_800))
def load_redlist_keywords():
    """Return the ``name`` attribute of every ``RedList`` row as a list.

    The call is wrapped in a ``TTLCache`` (``ttl=1_800``; seconds with
    cachetools' default timer), so the rows are only re-read once the cached
    entry has expired.
    """
    return [red_row.name for red_row in RedList.objects.all()]


class LowQualityKeywordsEvaluator(Evaluator):
    """Score a keyword by red-listed words, stop words and special characters.

    The final score is the arithmetic mean of three sub-scores, each in the
    range ``[0.0, 1.0]`` where higher means better quality:

    * share of tokens that are not on the red list,
    * ``1.0`` if every character is whitelisted, else ``0.0``,
    * share of tokens that are not stop words.
    """

    def __init__(self):
        # Red-list names are loaded through the module-level cached loader.
        self.blacklisted_keywords: List[str] = load_redlist_keywords()

    def evaluate(self, keyword: str) -> float:
        """Return the mean of the three sub-scores for ``keyword``.

        The keyword is split on whitespace and each token is lower-cased
        before the token-based checks run. An empty or whitespace-only
        keyword yields no tokens and is scored ``0.0`` instead of raising
        ``ZeroDivisionError``.
        """
        keyword_tokens = [w.lower() for w in keyword.split()]
        if not keyword_tokens:
            # Nothing to evaluate: treat as the lowest possible quality.
            return 0.0

        score_list = [
            self.blacklisted_words_score(keyword_tokens, self.blacklisted_keywords),
            self.special_characters_score(keyword),
            self.stopwords_score(keyword_tokens),
        ]

        return mean(score_list)

    @staticmethod
    def _unflagged_ratio(keyword_tokens: List[str], flagged) -> float:
        """Return the share of ``keyword_tokens`` that are not in ``flagged``.

        Returns ``0.0`` for an empty token list so callers never divide by
        zero.
        """
        if not keyword_tokens:
            return 0.0
        num_flagged = sum(1 for word in keyword_tokens if word in flagged)
        return 1 - num_flagged / len(keyword_tokens)

    def blacklisted_words_score(
        self, keyword_tokens: List[str], black_list: List[str]
    ) -> float:
        """Return the share of tokens that do not appear in ``black_list``.

        ``1.0`` means no token is black-listed, ``0.0`` means all of them are
        (or that ``keyword_tokens`` is empty).
        """
        # NOTE(review): tokens passed in by ``evaluate`` are lower-cased while
        # ``black_list`` entries are compared as stored - confirm the red list
        # is kept in lower case.
        # A set gives constant-time membership checks for every token.
        return self._unflagged_ratio(keyword_tokens, set(black_list))

    def stopwords_score(self, keyword_tokens: List[str]) -> float:
        """Return the share of tokens that are not in ``stop_words.STOP_WORDS``.

        ``1.0`` means no token is a stop word, ``0.0`` means all of them are
        (or that ``keyword_tokens`` is empty).
        """
        return self._unflagged_ratio(keyword_tokens, stop_words.STOP_WORDS)

    def special_characters_score(self, keyword: str) -> float:
        """Return ``1.0`` if every character is whitelisted, otherwise ``0.0``.

        An empty string contains no offending character and scores ``1.0``.
        """
        if all(character in WHITELISTED_CHARACTERS for character in keyword):
            return 1.0
        return 0.0
