from typing import Any, Dict

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.spacy.constants import KeywordsProcessingMode
from apadata.text_processors.keywords.enrichers.lemma_finder_enricher import (
    LemmaFinderEnricher,
)
from apadata.text_processors.keywords.reducers.lemma_finder_reducer import (
    LemmaFinderReducer,
)

# Payload type consumed by this module (first type parameter of PipelineModule).
Input = Dict[str, Any]
# Payload type produced by this module (second type parameter of PipelineModule).
Output = Dict[str, Any]


class LemmaFinderModule(PipelineModule[Input, Output]):
    """
    Pipeline module that runs a LemmaFinder text processor on the keyword field.

    The ``mode`` chosen at construction decides which processor is used:
    ``LemmaFinderReducer`` for ``KeywordsProcessingMode.REDUCE`` and
    ``LemmaFinderEnricher`` for ``KeywordsProcessingMode.ENRICH``.
    """

    def __init__(
        self,
        name: str = "lemma_finder",
        mode: str = KeywordsProcessingMode.REDUCE,
        **kwargs: Any,
    ):
        """
        Args:
            name: Module name, forwarded to ``PipelineModule``.
            mode: Processing mode; ``KeywordsProcessingMode.REDUCE`` by default.
                The value is stored as-is and only validated when ``run`` is
                called.
            **kwargs: Additional options forwarded unchanged to
                ``PipelineModule``.
        """
        super().__init__(name, **kwargs)
        self.mode = mode

    def run(self, context: PipelineContext) -> Output:
        """
        Process the context's ``"keyword"`` field with the configured processor.

        Args:
            context: Pipeline context; ``search_field("keyword")`` is used to
                fetch the value handed to the processor.

        Returns:
            A new ``dict`` built from ``context.payload``.

        Raises:
            ValueError: If ``self.mode`` matches neither supported mode.
        """
        # The lookup intentionally happens before the mode is validated.
        target = context.search_field("keyword")
        processor_cls = self._processor_class()
        # NOTE(review): the value returned by ``process()`` is discarded, so this
        # presumably relies on the processor updating ``target`` in place —
        # confirm against the processor implementations.
        processor_cls(target).process()
        return dict(context.payload)

    def _processor_class(self):
        """Return the processor class for ``self.mode``, or raise ``ValueError``."""
        # Plain ``==`` checks rather than a dict lookup: equality with an enum
        # member does not imply equal hashes, so a mapping could miss matches.
        if self.mode == KeywordsProcessingMode.REDUCE:
            return LemmaFinderReducer
        if self.mode == KeywordsProcessingMode.ENRICH:
            return LemmaFinderEnricher
        raise ValueError(f"Wrong mode given: {self.mode}!")
