from typing import Any

from apadata.modules.text_embedder_processor_module import TextEmbedderProcessorModule
from apadata.pipelines.pipeline import Pipeline
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.utils import flatten


class TextEmbedderProcessorPipeline(Pipeline):
    """
    Pipeline that embeds a piece of text with a single processing module.

    The pipeline is built around exactly one ``TextEmbedderProcessorModule``.
    After the pipeline has run, ``post_process`` reads the ``"embedding"``
    entry from the context and returns it flattened into a plain list.

    Parameters
    ----------
    context: PipelineContext
        Shared context handed to the base ``Pipeline``. It carries data
        between modules and is enriched by them as the pipeline runs.
    """

    def __init__(
        self,
        context: PipelineContext,
    ) -> None:
        # The only stage of this pipeline is the text-embedder module.
        stages = [TextEmbedderProcessorModule()]
        super().__init__(context=context, modules=stages)

    def post_process(self, context: PipelineContext) -> Any:
        """
        Collect the embedding produced by the pipeline.

        Parameters
        ----------
        context: PipelineContext
            Context after the modules have run. Its ``"embedding"`` entry is
            passed through ``apadata.utils.flatten``.

        Returns
        -------
        dict
            ``{"embedding": [...]}``, where the list holds every item yielded
            by ``flatten``. NOTE(review): this presumably turns a nested
            embedding into a flat vector; confirm against ``flatten``.
        """
        raw_embedding = context.get("embedding")
        # Materialise whatever iterable ``flatten`` yields into a concrete list.
        flat_values = [*flatten(raw_embedding)]
        return {"embedding": flat_values}
