from typing import Any, Dict

from apadata.elasticsearch_client.elasticsearch_client import ElasticsearchClient
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.spacy.constants import CONTENTS_SEPARATOR, URLS_SEPARATOR
from apadata.utils import flatten

# Type parameters supplied to PipelineModule by the module class in this file.
# Both are plain string-keyed dictionaries with arbitrary values.
Input = Dict[str, Any]
# The module's run() result is typed as Output.
Output = Dict[str, Any]


class ElasticsearchClientModule(PipelineModule[Input, Output]):
    """Pipeline module that gathers web contents through ElasticsearchClient.

    For every URL to look up, the module runs one search on a freshly
    created ``ElasticsearchClient`` and joins the retrieved web contents
    into a single text.
    """

    def __init__(
        self,
        name: str = "elastic",
        **kwargs: Any,
    ):
        """Create the module.

        Args:
            name: Module name forwarded to ``PipelineModule``.
            **kwargs: Extra keyword arguments forwarded to ``PipelineModule``.
        """
        super().__init__(name, **kwargs)

    def run(self, context: PipelineContext) -> Output:
        """Search Elasticsearch for each URL and join the web contents.

        Fields read from ``context`` via ``search_field`` (the second
        argument is presumably the fallback when the field is absent --
        confirm against ``PipelineContext``):

        * ``domain`` -- passed to every search; also the last-resort URL.
        * ``keyword`` -- passed to every search (fallback ``""``).
        * ``size`` -- passed to every search (fallback ``1``).
        * ``urls`` -- explicit URLs to search; used as-is when non-empty.
        * ``categories_to_urls`` -- mapping whose values are flattened into
          the URL list; consulted only when ``urls`` is empty.

        The URL list that is finally used is stored back on the context
        under ``"urls"``.

        Args:
            context: Pipeline context the fields are read from and the
                resolved ``"urls"`` are added to.

        Returns:
            A dict with a single ``"text"`` key. Its value joins the web
            contents of each URL's hits with ``CONTENTS_SEPARATOR`` and the
            per-URL results with ``URLS_SEPARATOR``.
        """
        elasticsearch_client = ElasticsearchClient()
        domain = context.search_field("domain")
        keyword = context.search_field("keyword", "")
        size = context.search_field("size", 1)

        # Explicitly supplied URLs take precedence. Previously they were
        # always overwritten (with None) by the categories lookup below, so
        # the search silently fell back to the bare domain.
        urls = context.search_field("urls", [])
        if not urls:
            categories_to_urls = context.search_field("categories_to_urls", {})
            if categories_to_urls:
                urls = flatten(list(categories_to_urls.values()))

        # Nothing usable from either field: search the domain itself.
        if not urls:
            urls = [domain]
        context.add("urls", urls)

        url_hits = [
            elasticsearch_client.search(
                domain=domain,
                url=url,
                keyword=keyword,
                size=size,
            )
            for url in urls
        ]
        contents = URLS_SEPARATOR.join(
            [
                CONTENTS_SEPARATOR.join(elasticsearch_client.get_web_contents(url_hit))
                for url_hit in url_hits
            ]
        )
        return {"text": contents}
