from typing import Any, Dict, Union

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.text_processors.text_embedder_processor import TextEmbedderProcessor

# Payload type accepted by the module: either raw text or a dict.
# How the text is located inside a dict is up to PipelineModule's
# get_text_from_payload helper (not shown in this file).
Input = Union[str, Dict[str, Any]]
# The module hands back the payload it received, so its output type is the
# same union as its input type.
Output = Input


class TextEmbedderProcessorModule(PipelineModule[Input, Output]):
    """Pipeline module that computes a text embedding for the current payload.

    On each run the module extracts text from ``context.payload``, passes it to
    a freshly constructed :class:`TextEmbedderProcessor`, stores the result of
    ``process()`` on the context under the ``"embedding"`` key, and returns the
    payload object it was given so the next module receives the same input.
    """

    def __init__(self, name: str = "text_embedder", **kwargs: Any):
        """Create the module.

        Args:
            name: Module name forwarded to ``PipelineModule``; defaults to
                ``"text_embedder"``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``PipelineModule.__init__``.
        """
        super().__init__(name, **kwargs)

    def run(self, context: PipelineContext) -> Output:
        """Embed the payload's text and record the embedding on ``context``.

        Args:
            context: Pipeline context whose ``payload`` is read and to which an
                ``"embedding"`` entry is added via ``context.add``.

        Returns:
            The same payload object read from ``context.payload``.
        """
        payload: Input = context.payload
        # Call the inherited helper through ``self`` rather than ``super()`` so
        # a subclass that overrides ``get_text_from_payload`` is honoured; for
        # this class it still resolves to the PipelineModule implementation.
        text = self.get_text_from_payload(payload)
        # A new processor is built per call because the text is supplied to
        # the constructor, not to ``process()``.
        embedding = TextEmbedderProcessor(text=text).process()
        # NOTE(review): the embedding's concrete type (list, ndarray, ...) is
        # determined by TextEmbedderProcessor.process(), which is not visible
        # here.
        context.add("embedding", embedding)
        return payload
