from typing import Any, Dict

from apadata.api.clearbit.clearbit_api import ClearbitAPI
from apadata.api.exceptions import UnauthorizedRequestError
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule
from apadata.utils import flatten

# Input type parameter of this module's PipelineModule specialization: a str-keyed dict.
Input = Dict[str, Any]
# Output type parameter: ClearbitAPIModule.run returns a dict copy of the context payload.
Output = Dict[str, Any]


class ClearbitAPIModule(PipelineModule[Input, Output]):
    """
    Look up the industries of the target companies via the Clearbit API.

    The module reads the company names found in the pipeline context under
    ``company_names``. It first uses Clearbit's name-to-domain suggestion
    component to resolve one domain per company. It then uses the industry
    component to look up the industries of those domains.

    The industries are merged into the context under the key
    ``target_industries``. They are not returned as a standalone result because
    that attribute is enriched further at the end of the pipeline.
    """

    def __init__(
        self,
        name: str = "clearbit",
        **kwargs: Any,
    ):
        """
        Create the module.

        :param name: name of the module within the pipeline.
        :param kwargs: forwarded unchanged to ``PipelineModule.__init__``.
        """
        super().__init__(
            name,
            **kwargs,
        )

    def run(self, context: PipelineContext) -> Output:
        """
        Resolve domains and industries and merge them into ``context``.

        :param context: pipeline context; ``company_names`` is read from it.
        :return: a shallow ``dict`` copy of ``context.payload`` after the merge.
        """
        api_instance = ClearbitAPI()
        company_names = context.search_field("company_names")
        domains = self._suggest_domains(api_instance, company_names)
        industries = self._find_industries(api_instance, domains)
        context.merge("target_industries", industries)
        return dict(context.payload)

    @staticmethod
    def _suggest_domains(api_instance: ClearbitAPI, company_names: Any) -> list:
        """Return the top suggested domain for every company that has one."""
        try:
            suggestions = [
                api_instance.suggest(company_name, 1) for company_name in company_names
            ]
        except UnauthorizedRequestError:
            # Best effort: without valid credentials this module contributes
            # no domains instead of failing the whole pipeline.
            return []
        # A company Clearbit cannot match yields an empty suggestion list;
        # skip it rather than crashing with IndexError on ``[0]``.
        # NOTE(review): assumes ``suggest`` returns a (possibly empty) sequence,
        # as the original ``[0]`` indexing implies.
        return [suggestion[0] for suggestion in suggestions if suggestion]

    @staticmethod
    def _find_industries(api_instance: ClearbitAPI, domains: list) -> Any:
        """Return the flattened industries of ``domains`` ([] if unauthorized)."""
        try:
            return flatten(
                [api_instance.industry(domain=domain) for domain in domains]
            )
        except UnauthorizedRequestError:
            # Same best-effort policy as the domain lookup above.
            return []
