from typing import Any, Dict, Union

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.spacy import SpacyConfiguration
from apadata.spacy.constants import (
    LangCode,
    SpacyGenre,
    SpacyPipeline,
    SpacyWordRootOption,
)
from apadata.text_processors import SpacyTextProcessor

# Payload type handled by SpacyTextProcessorModule.run: either the raw text, or
# a dict carrying a "text" entry and an optional "language" entry.
Input = Union[str, Dict[str, Any]]
# run() returns the payload it received unchanged, so Output mirrors Input.
Output = Union[str, Dict[str, Any]]


class SpacyTextProcessorModule(PipelineModule[Input, Output]):
    """Pipeline module that runs a spaCy text processor to find company names.

    The payload is either the raw text or a dict with a ``"text"`` entry and an
    optional ``"language"`` entry. The company names found in the text are
    stored in the pipeline context under ``"company_names"``; the payload
    itself is passed through unchanged.
    """

    # Used when the pipeline context has no "max_input_length" field; the value
    # is forwarded to SpacyConfiguration.
    DEFAULT_MAX_INPUT_LENGTH = 3_000_000

    def __init__(self, name: str = "spacy", **kwargs: Any):
        super().__init__(name, **kwargs)

    def run(self, context: PipelineContext) -> Output:
        """Extract company names from the payload text into ``context``.

        Args:
            context: Pipeline context. Its payload supplies the text. Its
                ``"language"`` entry is used when the payload does not carry a
                (non-None) language, and its ``"max_input_length"`` field is
                passed on to the spaCy configuration.

        Returns:
            The original payload, unchanged.

        Raises:
            ValueError: If the payload is neither a dict nor a string.
            KeyError: If a dict payload has no ``"text"`` entry.
        """
        payload: Input = context.payload

        language = None
        if isinstance(payload, dict):
            text = payload["text"]
            # Treat an explicit None like a missing key so the context
            # language is used as the fallback instead of LangCode(None).
            raw_language = payload.get("language")
            if raw_language is not None:
                language = LangCode(raw_language)
        elif isinstance(payload, str):
            text = payload
        else:
            raise ValueError("Payload should have been a dictionary or a string!")

        # Compare against None rather than relying on truthiness, so a
        # payload-provided language is never overridden just because its
        # LangCode value happens to be falsy.
        if language is None:
            # NOTE(review): assumes context.get returns None/raises for a
            # missing key; LangCode will reject an unusable value.
            language = LangCode(context.get("language"))

        max_input_length = context.search_field(
            "max_input_length", self.DEFAULT_MAX_INPUT_LENGTH
        )

        # NOTE(review): if these constant types are Enum classes, wrapping a
        # member in its own constructor is a no-op and could be simplified.
        pipeline_genre = SpacyGenre(SpacyGenre.CORE_WEB)
        pipeline_type = SpacyPipeline(SpacyPipeline.SM_PIPELINE)
        word_root_pipeline = SpacyWordRootOption(SpacyWordRootOption.LEMMATIZER)

        spacy_config = SpacyConfiguration(
            lang_code=language,
            pipeline_type=pipeline_type,
            pipeline_genre=pipeline_genre,
            word_root_pipeline=word_root_pipeline,
            max_input_length=max_input_length,
        )
        processor = SpacyTextProcessor(text=text, spacy_config=spacy_config)
        processor.obtain_doc()
        context.add("company_names", processor.get_company_names())
        return payload
