from typing import Any, Dict

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.spacy.constants import KeywordsProcessingMode
from apadata.text_processors.keywords.enrichers.similarity_finder_enricher import (
    SimilarityFinderEnricher,
)
from apadata.text_processors.keywords.reducers.similarity_finder_reducer import (
    SimilarityFinderReducer,
)

# Payload shape consumed by the module; the produced payload has the same shape.
Input = Dict[str, Any]
Output = Input


class SimilarityFinderModule(PipelineModule[Input, Output]):
    """
    Pipeline module that handles groups of two or more keywords whose mutual
    similarity is above a fixed threshold.

    The ``keyword`` field is looked up in the pipeline context and handed to a
    processor selected by ``mode``:

    - ``KeywordsProcessingMode.REDUCE``: ``SimilarityFinderReducer``. Per this
      module's original description, all keywords of such a group except one
      (chosen randomly) are removed.
    - ``KeywordsProcessingMode.ENRICH``: ``SimilarityFinderEnricher``; see that
      class for its exact effect.

    The processor's ``process()`` return value is ignored, so it presumably
    updates the keyword data in place — TODO confirm against the processors.
    """

    def __init__(
        self,
        name: str = "similarity_finder",
        mode: str = KeywordsProcessingMode.REDUCE,
        **kwargs: Any,
    ):
        """
        Args:
            name: Module name forwarded to ``PipelineModule``.
            mode: Either ``KeywordsProcessingMode.REDUCE`` (default) or
                ``KeywordsProcessingMode.ENRICH``. The value is stored as-is and
                only validated when ``run`` is called.
            **kwargs: Extra options forwarded to ``PipelineModule``.
        """
        super().__init__(
            name,
            **kwargs,
        )
        self.mode = mode

    def run(self, context: PipelineContext) -> Output:
        """
        Apply the configured similarity processor to the context's keywords.

        Args:
            context: Pipeline context that provides the ``keyword`` field and
                the payload.

        Returns:
            A shallow copy (``dict(...)``) of ``context.payload`` taken after
            processing.

        Raises:
            ValueError: If ``self.mode`` is neither REDUCE nor ENRICH.
        """
        # Resolve the processor before touching the context, so that a
        # misconfigured mode fails fast without looking anything up.
        # Equality (not dict lookup) is used deliberately: it keeps working
        # whether ``mode`` is a plain string or a KeywordsProcessingMode member.
        processor_cls: Any
        if self.mode == KeywordsProcessingMode.REDUCE:
            processor_cls = SimilarityFinderReducer
        elif self.mode == KeywordsProcessingMode.ENRICH:
            processor_cls = SimilarityFinderEnricher
        else:
            raise ValueError(f"Wrong mode given: {self.mode}!")

        keyword = context.search_field("keyword")
        processor_cls(keyword).process()
        return dict(context.payload)
