from typing import Any, List

import string
from itertools import chain
from statistics import mean

from spacy.lang.en import stop_words

from apadata.loaders import CSVLoader
from apadata.text_processors.evaluators.evaluator import Evaluator

# Blacklist ("red list") rows, loaded once when this module is imported.
# NOTE(review): the path is relative; whether it resolves against the current
# working directory or a configured data root depends on CSVLoader — confirm.
black_list_csv_data = CSVLoader("text_processors/red_list.csv").load()

# The lower-cased "text" column of every blacklist row.
BLACKLISTED_KEYWORDS: List[str] = [
    row["text"].lower() for row in black_list_csv_data
]

# Characters a keyword may contain: ASCII letters, ASCII digits, the space
# character and the punctuation marks , . & /
WHITELISTED_CHARACTERS = set(string.ascii_letters + string.digits + " ,.&/")


class LowQualityKeywordsEvaluator(Evaluator):
    """Score how "clean" a keyword is, from 0.0 (low quality) to 1.0.

    The score is the mean of three sub-scores:

    * the fraction of tokens that are not in ``BLACKLISTED_KEYWORDS``,
    * 1.0 if every character is in ``WHITELISTED_CHARACTERS``, else 0.0,
    * the fraction of tokens that are not spaCy English stop words.

    Tokens come from splitting the keyword on whitespace and lower-casing.
    Matching is exact per token, so a blacklist entry that itself contains
    whitespace can never match.
    """

    def evaluate(self, keyword: str, **kwargs: Any) -> float:
        """Return the mean of the blacklist, special-character and stopword scores.

        Args:
            keyword: The keyword or phrase to score.
            **kwargs: Ignored by this evaluator.

        Returns:
            A score in ``[0.0, 1.0]``. A keyword with no tokens (empty or
            whitespace-only) scores ``0.0`` rather than raising
            ``ZeroDivisionError``.
        """
        keyword_tokens = [w.lower() for w in keyword.split()]
        if not keyword_tokens:
            # Nothing to judge: treat an empty keyword as lowest quality.
            return 0.0

        score_list = [
            self.blacklisted_words_score(keyword_tokens, BLACKLISTED_KEYWORDS),
            self.special_characters_score(keyword),
            self.stopwords_score(keyword_tokens),
        ]
        return mean(score_list)

    def blacklisted_words_score(
        self, keyword_tokens: List[str], black_list: List[str]
    ) -> float:
        """Return the fraction of ``keyword_tokens`` that are not in ``black_list``.

        The comparison is exact, so tokens and list entries must use the same
        casing (``evaluate`` lower-cases tokens). Returns ``0.0`` when
        ``keyword_tokens`` is empty.
        """
        if not keyword_tokens:
            return 0.0
        # NOTE(review): punctuation stays attached to tokens ("foo," != "foo"),
        # so a blacklisted word followed by a comma is not counted — confirm
        # whether that is intended.
        num_blacklisted = sum(1 for word in keyword_tokens if word in black_list)
        return 1 - num_blacklisted / len(keyword_tokens)

    def stopwords_score(self, keyword_tokens: List[str]) -> float:
        """Return the fraction of ``keyword_tokens`` that are not spaCy stop words.

        Returns ``0.0`` when ``keyword_tokens`` is empty.
        """
        if not keyword_tokens:
            return 0.0
        num_stopwords = sum(
            1 for word in keyword_tokens if word in stop_words.STOP_WORDS
        )
        return 1 - num_stopwords / len(keyword_tokens)

    def special_characters_score(self, keyword: str) -> float:
        """Return 1.0 if every character of ``keyword`` is whitelisted, else 0.0.

        An empty string has no offending characters and scores 1.0.
        """
        if all(character in WHITELISTED_CHARACTERS for character in keyword):
            return 1.0
        return 0.0
