from typing import List, Set, Union

from apadata.api.clearbit.clearbit_api import ClearbitAPI
from apadata.models import ExternalIndustry

from ..spacy.constants import LangCode
from ..utils import flatten
from . import LangDetectTextProcessor, TextProcessor


class TargetIndustriesExtractor(TextProcessor):
    """
    Detect which known industries are mentioned in a piece of text.

    ``process`` detects the language of ``text`` and returns the ``external_id``
    of every ``ExternalIndustry`` that ``extract_target_industries_mentions``
    finds in it. ``extract_target_industries_api`` is an alternative lookup that
    resolves company names to domains, and domains to industries, through the
    Clearbit API.

    Parameters
    ----------
    text : str
        Text to analyse (e.g. web content).
    api_instance : ClearbitAPI or None, optional
        Client used by ``extract_target_industries_api``. When omitted, a new
        ``ClearbitAPI()`` is created for this extractor. Pass a falsy value
        (e.g. ``None``) to disable API lookups; that method then returns ``[]``.
    """

    # Sentinel meaning "no api_instance argument was given". Using it instead of
    # ``ClearbitAPI()`` as the default avoids building a client at import time
    # (and sharing that one instance between all extractors), while keeping
    # ``None`` available as an explicit "disable the API" value.
    _DEFAULT_API = object()

    # Connector words ignored when matching multi-word industry phrases.
    _CONNECTOR_WORDS = frozenset({"&", "and"})

    def __init__(
        self,
        text: str,
        api_instance: Union[ClearbitAPI, None] = _DEFAULT_API,  # type: ignore[assignment]
    ):
        super().__init__(text=text)
        if api_instance is self._DEFAULT_API:
            api_instance = ClearbitAPI()
        self.api_instance = api_instance
        # Loaded once per extractor and reused by every extraction call.
        self.industries: List[ExternalIndustry] = ExternalIndustry.load_all()

    def process(self) -> List[str]:
        """
        Detect the language of ``self.text`` and return the external ids of the
        industries mentioned in it (see ``extract_target_industries_mentions``).
        """
        detected_language = LangDetectTextProcessor(self.text).process()
        return self.extract_target_industries_mentions(
            web_content=self.text, language=LangCode(detected_language)
        )

    def extract_target_industries_api(self, company_names: List[str]) -> List[str]:
        """
        Look up industries for the given companies through the Clearbit API.

        Each company name is resolved to its top suggested domain with
        ``api_instance.suggest(name, count=1)``; names without any suggestion
        are skipped. Each distinct domain is then passed once to
        ``api_instance.industry``.

        Parameters
        ----------
        company_names : List[str]
            Company names to resolve.

        Returns
        -------
        List[str]
            Distinct values returned by ``api_instance.industry``, in first-seen
            order. Empty when no API instance is configured.
        """
        if not self.api_instance:
            return []

        suggestions_per_company = [
            self.api_instance.suggest(company_name, count=1)
            for company_name in company_names
        ]
        # dict.fromkeys de-duplicates while preserving first-seen order, so the
        # output is deterministic (a set's iteration order is not).
        domains = dict.fromkeys(
            suggestions[0] for suggestions in suggestions_per_company if suggestions
        )
        industries = dict.fromkeys(
            self.api_instance.industry(domain) for domain in domains
        )
        return list(industries)

    def includes_phrases(
        self, web_content: Union[str, Set[str]], phrases: Set[str]
    ) -> bool:
        """
        Return True if ``web_content`` contains every word of at least one phrase.

        Each phrase is lower-cased and split on whitespace, and the connector
        words ``&`` and ``and`` are dropped. The remaining words are each tested
        with ``in`` independently, so they may appear anywhere and in any order.
        ``web_content`` is not lower-cased here; callers should pass it
        lower-cased. Phrases with no remaining words never match.

        NOTE(review): when ``web_content`` is a ``str`` this is a substring test,
        so e.g. "art" also matches "start". Passing a set of tokens gives
        whole-word matching instead.
        """
        for phrase in phrases:
            words = [
                word
                for word in phrase.lower().split()
                if word not in self._CONNECTOR_WORDS
            ]
            # Guard against empty word lists: all([]) is True, which would make
            # an empty (or connector-only) phrase match any content.
            if words and all(word in web_content for word in words):
                return True
        return False

    def extract_target_industries_mentions(
        self, web_content: str, language: Union[LangCode, str]
    ) -> List[str]:
        """
        Return the external ids of the industries mentioned in ``web_content``.

        An industry counts as mentioned when ``includes_phrases`` matches any of
        its searchable entities for ``language``. Both the content and the
        entities are lower-cased, so matching is case-insensitive.

        Parameters
        ----------
        web_content : str
            Text to search.
        language : LangCode or str
            Language passed to ``ExternalIndustry.searchable_entities``.

        Returns
        -------
        List[str]
            Distinct ``external_id`` values, in the order of ``self.industries``.
        """
        lowered_content = web_content.lower()
        # dict used as an insertion-ordered set, so the output is deterministic.
        mentioned = {}
        for industry in self.industries:
            entities = flatten(list(industry.searchable_entities(language).values()))
            phrases = {entity.lower() for entity in entities}
            if self.includes_phrases(lowered_content, phrases):
                mentioned[industry.external_id] = None
        return list(mentioned)
