from typing import Any, Dict

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.spacy.constants import KeywordsProcessingMode
from apadata.text_processors.keywords.enrichers.cross_translator_enricher import (
    CrossTranslatorEnricher,
)
from apadata.text_processors.keywords.reducers.cross_translator_reducer import (
    CrossTranslatorReducer,
)

# Payload types for CrossTranslatorModule: both the module's input and its
# output are string-keyed dictionaries with arbitrary values.
Input = Dict[str, Any]
Output = Dict[str, Any]


class CrossTranslatorModule(PipelineModule[Input, Output]):
    """
    Pipeline module that applies cross-translation to the context's
    ``keyword`` field.

    Depending on ``mode`` it runs either ``CrossTranslatorReducer``
    (``KeywordsProcessingMode.REDUCE``) or ``CrossTranslatorEnricher``
    (``KeywordsProcessingMode.ENRICH``) on the value returned by
    ``context.search_field("keyword")``, then returns a shallow copy of the
    context payload.
    """

    def __init__(
        self,
        name: str = "cross_translator",
        mode: str = KeywordsProcessingMode.REDUCE,
        **kwargs: Any,
    ):
        """
        Args:
            name: Module name, forwarded to ``PipelineModule``.
            mode: ``KeywordsProcessingMode.REDUCE`` or
                ``KeywordsProcessingMode.ENRICH``.
            **kwargs: Extra options, forwarded to ``PipelineModule``.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Fail fast on a misconfigured mode. Otherwise the error would only
        # surface when the pipeline reaches this module, after earlier
        # modules have already done their work. Checked before
        # super().__init__ so the base initializer never runs on a bad config.
        self._processor_for(mode)
        super().__init__(
            name,
            **kwargs,
        )
        self.mode = mode

    @staticmethod
    def _processor_for(mode: str) -> type:
        """Return the keyword processor class for ``mode``; raise ValueError if unknown."""
        if mode == KeywordsProcessingMode.REDUCE:
            return CrossTranslatorReducer
        if mode == KeywordsProcessingMode.ENRICH:
            return CrossTranslatorEnricher
        raise ValueError(f"Wrong mode given: {mode}!")

    def run(self, context: PipelineContext) -> Output:
        """
        Run the processor selected by ``self.mode`` on the context's keyword.

        Args:
            context: Pipeline context; its ``"keyword"`` field is processed.

        Returns:
            A new dict built from ``context.payload`` (shallow copy).

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        # ``mode`` is a public attribute and may have been reassigned after
        # construction, so it is re-validated here, before touching the context.
        processor_cls = self._processor_for(self.mode)
        keyword = context.search_field("keyword")
        # The return value of process() is discarded; presumably the
        # processor updates ``keyword`` in place — TODO confirm.
        processor_cls(keyword).process()
        return dict(context.payload)
