from typing import Any, Dict, List, Optional

from apadata.constants import DEFAULT_SCORE
from apadata.pipelines.pipeline import Pipeline
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.text_processors import SpacyTextProcessor

from ..strategies import ScoreStrategyEnum, ScoreStrategyFactory
from . import (
    DifficultyEvaluatorModule,
    ElasticsearchEvaluatorModule,
    GeneralityEvaluatorModule,
    KeywordsFrequencyEvaluatorModule,
    LanguageFrequencyEvaluatorModule,
    LengthEvaluatorModule,
    LinguisticalEvaluatorModule,
    LocationEvaluatorModule,
    LowQualityKeywordsEvaluatorModule,
    SpellingEvaluatorModule,
    SynonymyEvaluatorModule,
    WikipediaEvaluatorModule,
)

# NOTE(review): these two module-level lists are not referenced anywhere in
# the visible part of this file; EvaluatorPipeline keeps its own instance
# attributes with the same names. Confirm whether other modules import them.
out_of_interval_values: List[str] = []
uncalculated_criteria: List[str] = []

# Module-level dictionaries, empty at import time. Presumably shared caches
# filled in by the evaluator modules (keyed as the annotations suggest) --
# TODO confirm against the code that reads and writes them.
spell_checker_dict: Dict[str, Any] = {}
spacy_pipelines_dictionary: Dict[str, SpacyTextProcessor] = {}


class EvaluatorPipeline(Pipeline):
    """
    Pipeline configured with all the evaluator modules which, once they have
    been processed, aggregates the individual scores stored in the context
    into a single relevance score

    Parameters
    ----------
    context: PipelineContext
        Context of the pipeline which receives and passes along data across a
        pipeline enriching it with information from several modules
    """

    # Names of the score fields read from the context by ``post_process``.
    # The order matters: it is the insertion order of the ``scores_values``
    # dictionary published in the context, hence also the order of
    # ``feature_vector`` and of the input given to the MEAN strategy.
    _SCORE_NAMES = (
        "language_frequency_score",
        "wikipedia_score",
        "synonymy_score",
        "spelling_score",
        "low_quality_keywords_score",
        "location_score",
        "linguistical_score",
        "length_score",
        "keywords_frequency_score",
        "elasticsearch_score",
        "generality_score",
        "difficulty_score",
    )

    # Weight applied to each score by the WEIGHTED_MEAN strategy.
    _SCORE_WEIGHTS: Dict[str, float] = {
        "difficulty_score": 0.010176,
        "generality_score": 0.009502,
        "keywords_frequency_score": 0.008016,
        "language_frequency_score": 3.1e-05,
        "length_score": 0.010176,
        "linguistical_score": 0.009382,
        "location_score": 0.000627,
        "low_quality_keywords_score": 0.009366,
        "synonymy_score": 0.0033,
        "wikipedia_score": 0.00562,
        "spelling_score": 0.0,
        "elasticsearch_score": 0.0,
    }

    def __init__(
        self,
        context: PipelineContext,
    ):
        super().__init__(
            context=context,
            modules=[
                DifficultyEvaluatorModule(),
                KeywordsFrequencyEvaluatorModule(),
                GeneralityEvaluatorModule(),
                ElasticsearchEvaluatorModule(),
                LanguageFrequencyEvaluatorModule(),
                LengthEvaluatorModule(),
                LinguisticalEvaluatorModule(),
                LocationEvaluatorModule(),
                LowQualityKeywordsEvaluatorModule(),
                SpellingEvaluatorModule(),
                SynonymyEvaluatorModule(),
                WikipediaEvaluatorModule(),
            ],
        )
        # Rounded result of the last ``post_process`` call (None before it).
        self.relevance_score: Optional[float] = None
        # Inclusive bounds outside of which a score is reported as invalid.
        self.min_score_value = 0.0
        self.max_score_value = 1.0
        self.feature_vector: List[float] = []
        # Both lists are only ever appended to, so they accumulate findings
        # over successive ``post_process`` calls on the same instance.
        self.out_of_interval_values: List[float] = []
        self.uncalculated_criteria: List[str] = []

    def post_process(self, context: PipelineContext) -> Any:
        """
        Collects the scores from the context and computes the relevance score

        The collected scores and their weights are added to the context under
        the ``scores_values`` and ``scores_weights`` fields.

        Parameters
        ----------
        context: PipelineContext
            Context holding one field per score plus the ``strategy`` field

        Returns
        -------
        Dict[str, float]
            Dictionary with the ``relevance_score`` rounded to 4 decimals

        Raises
        ------
        ValueError
            If the strategy is neither MEAN nor WEIGHTED_MEAN
        """
        scores_values = {
            name: context.search_field(name) for name in self._SCORE_NAMES
        }
        # Fresh copy per call so that whoever mutates the context entry
        # cannot alter the class-level table.
        scores_weights = dict(self._SCORE_WEIGHTS)
        context.add("scores_values", scores_values)
        context.add("scores_weights", scores_weights)

        self.detect_out_of_interval_values(scores_values)
        self.detect_uncalculated_criteria(scores_values)
        # NOTE(review): this is a truthiness filter, so scores equal to 0.0
        # are dropped together with the missing (None) ones. Kept as is;
        # confirm whether zero scores should be part of the vector.
        self.feature_vector = [value for value in scores_values.values() if value]

        strategy = context.search_field("strategy")

        array: List[Any]
        if strategy == ScoreStrategyEnum.MEAN:
            array = [value for value in scores_values.values() if value is not None]
        elif strategy == ScoreStrategyEnum.WEIGHTED_MEAN:
            # (score, weight) pairs in alphabetical order of the metric name;
            # metrics without a score are left out.
            array = [
                (scores_values[metric], scores_weights[metric])
                for metric in sorted(scores_weights)
                if scores_values[metric] is not None
            ]
        else:
            raise ValueError(f"Wrong strategy: {strategy}")

        relevance_score = ScoreStrategyFactory.create(strategy).calculate(array)

        self.relevance_score = round(relevance_score, 4)

        return {"relevance_score": self.relevance_score}

    def detect_out_of_interval_values(
        self, scores_values: Dict[str, Optional[float]]
    ) -> List[float]:
        """
        Records the scores lying outside [min_score_value, max_score_value]

        Parameters
        ----------
        scores_values: Dict[str, Optional[float]]
            Scores by metric name; missing scores (None) are ignored

        Returns
        -------
        List[float]
            The instance list of out of interval values, which also contains
            the values found by previous calls
        """
        for result in scores_values.values():
            # A missing score cannot be ordered against the bounds (it would
            # raise TypeError) and is not an out of interval value either.
            if result is None:
                continue
            if result < self.min_score_value or result > self.max_score_value:
                self.out_of_interval_values.append(result)
        return self.out_of_interval_values

    def detect_uncalculated_criteria(
        self, scores_values: Dict[str, Optional[float]]
    ) -> List[str]:
        """
        Records the metrics whose score is still equal to DEFAULT_SCORE

        Parameters
        ----------
        scores_values: Dict[str, Optional[float]]
            Scores by metric name

        Returns
        -------
        List[str]
            The instance list of uncalculated metric names, which also
            contains the names found by previous calls
        """
        self.uncalculated_criteria.extend(
            metric_name
            for metric_name, result in scores_values.items()
            if result == DEFAULT_SCORE
        )
        return self.uncalculated_criteria
