from typing import Any

from statistics import mean

import nltk
from nltk.tokenize import word_tokenize
from textstat import (
    automated_readability_index,
    coleman_liau_index,
    dale_chall_readability_score,
    difficult_words,
    flesch_kincaid_grade,
    flesch_reading_ease,
    gulpease_index,
    gunning_fog,
    linsear_write_formula,
    osman,
    smog_index,
)
from textstat.textstat import textstatistics

from apadata.text_processors.evaluators.evaluator import Evaluator


class DifficultyEvaluator(Evaluator):
    """Score how relevant a keyword is using text-difficulty features.

    The score is the arithmetic mean of several ``textstat`` readability
    metrics plus the share of polysyllabic tokens (see
    :meth:`poly_syllable_count`). These features live on very different
    scales (e.g. Flesch reading ease vs. grade levels vs. a raw count of
    difficult words), so the result is best treated as a relative signal
    between keywords scored by this same evaluator, not as an absolute
    readability value.
    """

    def __init__(self):
        # word_tokenize (used in poly_syllable_count) relies on NLTK's Punkt
        # tokenizer data, so fetch it when the evaluator is constructed.
        # NOTE(review): this runs on every instantiation and may contact the
        # network each time; newer NLTK releases may also look up "punkt_tab"
        # rather than "punkt" — confirm against the pinned NLTK version.
        nltk.download("punkt")
        super().__init__()

    def evaluate(self, keyword: str, **kwargs: Any) -> float:
        """Return the mean of the text-difficulty features of *keyword*.

        Args:
            keyword: Text (a word or phrase) to score.
            **kwargs: Unused by this evaluator.

        Returns:
            The mean of the difficulty features as a ``float``, or ``0.0``
            when *keyword* is empty or whitespace-only (there is no text to
            measure, and the polysyllable ratio would otherwise divide by
            zero).
        """
        if not keyword.strip():
            return 0.0

        text_diff_features = [
            flesch_reading_ease(keyword),
            flesch_kincaid_grade(keyword),
            smog_index(keyword),
            coleman_liau_index(keyword),
            automated_readability_index(keyword),
            dale_chall_readability_score(keyword),
            difficult_words(keyword),
            linsear_write_formula(keyword),
            gunning_fog(keyword),
            gulpease_index(keyword),
            osman(keyword),
            self.poly_syllable_count(keyword),
        ]
        return float(mean(text_diff_features))

    def syllables_count(self, word: str) -> int:
        """Return the syllable count ``textstat`` estimates for *word*."""
        return int(textstatistics().syllable_count(word))

    def poly_syllable_count(self, keyword: str) -> float:
        """Return the fraction of tokens in *keyword* with 3+ syllables.

        Despite its name, this is a ratio in ``[0, 1]``, not a count.
        Tokens come from ``word_tokenize``, which emits punctuation as
        separate tokens; those tokens count toward the denominator.

        Args:
            keyword: Text to tokenize and inspect.

        Returns:
            The share of polysyllabic tokens, or ``0.0`` when *keyword*
            yields no tokens (e.g. an empty string) instead of raising
            ``ZeroDivisionError``.
        """
        keyword_tokens = word_tokenize(keyword)
        if not keyword_tokens:
            return 0.0
        # Only the number of matches is needed, so count without building a list.
        poly_syllable_words = sum(
            1 for word in keyword_tokens if self.syllables_count(word) >= 3
        )
        return poly_syllable_words / len(keyword_tokens)
