from typing import Any, Dict

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.text_processors.target_industries_extractor import (
    TargetIndustriesExtractor,
)

# Type aliases used as the generic parameters of the pipeline module below:
# both the module's input and its output are free-form, string-keyed dicts.
Input = Dict[str, Any]
Output = Dict[str, Any]


class TargetIndustriesMentionsModule(PipelineModule[Input, Output]):
    """Extract target-industry mentions from a page's text and its URLs.

    The module reads the ``text``, ``language`` and ``urls`` fields from the
    pipeline context, appends the URLs (with hyphens replaced by spaces) to
    the text, runs ``TargetIndustriesExtractor`` over the combined string and
    merges the extracted mentions into the context under
    ``target_industries``.
    """

    def __init__(self, name: str = "mentions", **kwargs: Any):
        """Create the module.

        Args:
            name: Module name forwarded to ``PipelineModule``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``PipelineModule.__init__``.
        """
        super().__init__(name, **kwargs)

    def run(self, context: PipelineContext) -> Output:
        """Extract target-industry mentions and merge them into ``context``.

        Args:
            context: Pipeline context providing the ``text``, ``language``
                and (optionally) ``urls`` fields.

        Returns:
            A ``dict`` copy of ``context.payload`` taken after the mentions
            were merged under the ``target_industries`` key.
        """
        # A missing/None text or url list is treated as empty so the string
        # handling below cannot fail on ``None``.
        web_content = context.search_field("text") or ""
        language = context.search_field("language")
        urls = context.search_field("urls", []) or []

        # Hyphens are turned into spaces, presumably so that the words of a
        # URL slug are seen as separate words by the extractor.
        url_words = " ".join(url.replace("-", " ") for url in urls)

        # Join text and URL words with an explicit space: plain concatenation
        # would fuse the last word of the text with the first URL. Empty
        # parts are skipped so no stray leading/trailing space is added.
        web_content_with_urls = " ".join(
            part for part in (web_content, url_words) if part
        )

        # NOTE(review): the same string is passed both to the constructor and
        # to the extraction call; kept as-is because the extractor's API is
        # defined elsewhere.
        target_industries_mentions_extractor = TargetIndustriesExtractor(
            text=web_content_with_urls
        )
        target_industries = (
            target_industries_mentions_extractor.extract_target_industries_mentions(
                web_content=web_content_with_urls, language=language
            )
        )

        context.merge("target_industries", target_industries, remove_duplicates=True)
        return dict(context.payload)
