from typing import Any, Dict

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.spacy.constants import KeywordsProcessingMode
from apadata.text_processors.keywords.enrichers.deacronymizer_enricher import (
    DeacronymizerEnricher,
)
from apadata.text_processors.keywords.reducers.deacronymizer_reducer import (
    DeacronymizerReducer,
)

# Type arguments for PipelineModule below: the module is typed as taking and
# returning plain string-keyed dictionaries.
Input, Output = Dict[str, Any], Dict[str, Any]


class DeacronymizerModule(PipelineModule[Input, Output]):
    """
    Pipeline step for keyword pairs in which one word is an acronym of the other.

    The original description of this module: it looks for such pairs and
    removes the acronym. The ``keyword`` field is read from the pipeline
    context and passed to a processor chosen by ``mode``:

    * ``KeywordsProcessingMode.REDUCE`` (default) -> ``DeacronymizerReducer``
    * ``KeywordsProcessingMode.ENRICH`` -> ``DeacronymizerEnricher``

    Presumably acronym removal is the REDUCE path. What ENRICH does is defined
    by ``DeacronymizerEnricher`` (TODO confirm). Any other ``mode`` makes
    ``run`` raise ``ValueError``.
    """

    def __init__(
        self,
        name: str = "deacronymizer",
        mode: str = KeywordsProcessingMode.REDUCE,
        **kwargs: Any,
    ):
        # NOTE(review): ``mode`` is annotated as ``str`` but defaults to a
        # KeywordsProcessingMode member. Confirm that enum is str-based.
        # ``name`` and any extra keyword arguments go to the base class.
        super().__init__(name, **kwargs)
        # Not validated here; an unsupported value only surfaces in ``run``.
        self.mode = mode

    def run(self, context: PipelineContext) -> Output:
        """
        Apply the processor selected by ``self.mode`` to the ``keyword`` field.

        :param context: pipeline context that ``keyword`` is looked up in.
        :return: a new plain ``dict`` built from ``context.payload``.
        :raises ValueError: if ``self.mode`` equals neither REDUCE nor ENRICH.
        """
        # The field is fetched before the mode is checked.
        keyword = context.search_field("keyword")

        # Ordered (mode, processor class) pairs. The first one whose mode
        # compares equal to ``self.mode`` wins.
        handlers = (
            (KeywordsProcessingMode.REDUCE, DeacronymizerReducer),
            (KeywordsProcessingMode.ENRICH, DeacronymizerEnricher),
        )
        for candidate, processor_cls in handlers:
            if self.mode == candidate:
                # The return value of ``process()`` is discarded. Presumably
                # the processor updates ``keyword`` in place (TODO confirm).
                processor_cls(keyword).process()
                break
        else:
            raise ValueError(f"Wrong mode given: {self.mode}!")

        return dict(context.payload)
