from statistics import mean

from nltk.tokenize import word_tokenize
from textstat import (
    automated_readability_index,
    coleman_liau_index,
    dale_chall_readability_score,
    difficult_words,
    flesch_kincaid_grade,
    flesch_reading_ease,
    gulpease_index,
    gunning_fog,
    linsear_write_formula,
    osman,
    smog_index,
)
from textstat.textstat import textstatistics

from apadata.text_processors.evaluators.evaluator import Evaluator


class DifficultyEvaluator(Evaluator):
    """Scores how relevant a keyword is from its text-difficulty features.

    Each readability metric is divided by a fixed scaling factor, capped at
    1.0, and the capped values are averaged into a single score in [0, 1].
    """

    def evaluate(self, keyword: str) -> float:
        """Return the difficulty-based relevance score of ``keyword``.

        Args:
            keyword: The keyword (one or more words) to score.

        Returns:
            The mean of the scaled, capped difficulty features, clamped to
            the closed interval [0.0, 1.0].
        """
        # (metric, divisor) pairs. The divisors bring the raw metrics onto a
        # comparable scale before averaging.
        # NOTE(review): the divisors look like rough upper bounds of each
        # metric -- confirm against the textstat documentation.
        # NOTE(review): flesch_reading_ease and gulpease_index presumably grow
        # with *ease* of reading while the others grow with difficulty --
        # confirm that mixing them in one mean is intended.
        scaled_metrics = (
            (flesch_reading_ease, 100.0),
            (flesch_kincaid_grade, 20.0),
            (smog_index, 20.0),
            (coleman_liau_index, 20.0),
            (automated_readability_index, 20.0),
            (dale_chall_readability_score, 10.0),
            (difficult_words, 1.0),
            (linsear_write_formula, 12.0),
            (gunning_fog, 20.0),
            (gulpease_index, 100.0),
            (osman, 1.0),
            (self.poly_syllable_count, 1.0),
        )
        # Cap every feature at 1.0 so a single out-of-scale metric cannot
        # dominate the mean. (Using max() here would lift every feature to
        # at least 1.0 and make the score a constant 1.0.)
        text_diff_features = [
            min(metric(keyword) / scale, 1.0) for metric, scale in scaled_metrics
        ]
        # Some scaled features can be negative, so the mean is clamped on
        # both sides.
        return min(1.0, max(0.0, float(mean(text_diff_features))))

    def syllables_count(self, word: str) -> int:
        """Return the number of syllables in ``word`` as counted by textstat."""
        return int(textstatistics().syllable_count(word))

    def poly_syllable_count(self, keyword: str) -> float:
        """Return the fraction of tokens in ``keyword`` with 3+ syllables.

        Args:
            keyword: The keyword to tokenize with ``word_tokenize``.

        Returns:
            The number of tokens with at least three syllables divided by the
            total number of tokens, or 0.0 when the keyword has no tokens.
        """
        keyword_tokens = word_tokenize(keyword)
        if not keyword_tokens:
            # Empty or whitespace-only keyword: avoid dividing by zero.
            return 0.0
        poly_syllable_words = [
            word for word in keyword_tokens if self.syllables_count(word) >= 3
        ]
        return len(poly_syllable_words) / len(keyword_tokens)
