from typing import Any, Dict, List, Optional, Tuple

from apadata.spacy import Spacy, SpacyConfiguration

from .target_industries_utils import NER_LABELS_CATEGORIES_DICT
from .text_processor import TextProcessor


class SpacyTextProcessor(TextProcessor):
    """
    Run a text through a spaCy pipeline and extract named entities by category.

    Parameters
    ----------
    text : str
        The text that will be passed to the spaCy pipeline in order to then
        obtain a spaCy document.
    spacy_config : SpacyConfiguration
        Contains all the required configuration for instantiating a spaCy
        pipeline.

    Attributes
    ----------
    spacy : Any
        The ``Spacy`` wrapper built from ``spacy_config``; ``None`` until
        :meth:`obtain_doc` has run.
    doc : Any
        The document produced by ``spacy.nlp(text)``; ``None`` until
        :meth:`obtain_doc` has run.
    document_entities : dict of str to list of str, optional
        Unique entity texts per category in ``ENTITY_TYPES``; populated by
        :meth:`process`.
    """

    #: Entity categories extracted by :meth:`process`. Each one must be a key
    #: of ``NER_LABELS_CATEGORIES_DICT``.
    ENTITY_TYPES: Tuple[str, ...] = (
        "company",
        "person",
        "space",
        "time",
        "product",
        "numeric",
    )

    def __init__(self, text: str, spacy_config: SpacyConfiguration):
        super().__init__(text)
        self.text = text
        self.spacy_config = spacy_config
        self.spacy: Any = None
        self.doc: Any = None
        self.document_entities: Optional[Dict[str, List[str]]] = None

    def process(self) -> Tuple[Any, Dict[str, List[str]]]:
        """
        Build the spaCy document and collect the entities of every category.

        Returns
        -------
        tuple
            ``(doc, document_entities)``, where ``document_entities`` maps each
            name in ``ENTITY_TYPES`` to the unique entity texts of that
            category, in order of first appearance in the document.

        Raises
        ------
        ValueError
            If the spaCy pipeline could not be initialized.
        """
        self.obtain_doc()
        self.document_entities = {
            ent_type: self.get_document_entities(ent_type)
            for ent_type in self.ENTITY_TYPES
        }
        return (self.doc, self.document_entities)

    def obtain_doc(self) -> Any:
        """
        Instantiate the spaCy pipeline and run it over ``self.text``.

        The result is stored in ``self.doc`` and also returned.

        Raises
        ------
        ValueError
            If the ``Spacy`` wrapper exposes a falsy ``nlp`` attribute.
        """
        self.spacy = Spacy(self.spacy_config)

        if not self.spacy.nlp:
            raise ValueError("nlp attribute not initialized!")

        self.doc = self.spacy.nlp(self.text)
        return self.doc

    def get_document_entities(self, entity_type: str = "company") -> List[str]:
        """
        Return the unique texts of entities whose label belongs to a category.

        If the document has not been built yet, :meth:`obtain_doc` is called
        first. Duplicates are removed while keeping the order in which entities
        first appear in the document, so repeated runs give the same result.

        Parameters
        ----------
        entity_type : str
            Key into ``NER_LABELS_CATEGORIES_DICT`` selecting which entity
            labels count for this category.

        Raises
        ------
        KeyError
            If ``entity_type`` is not a key of ``NER_LABELS_CATEGORIES_DICT``.
        ValueError
            If the document must be built and the pipeline fails to initialize.
        """
        if self.doc is None:
            # Allow the public getters to be used without calling process() first.
            self.obtain_doc()

        # Look the label collection up once instead of once per entity.
        labels = NER_LABELS_CATEGORIES_DICT[entity_type]
        # dict.fromkeys de-duplicates while preserving insertion order, unlike a
        # set whose iteration order depends on per-process string hashing.
        return list(
            dict.fromkeys(ent.text for ent in self.doc.ents if ent.label_ in labels)
        )

    def get_company_names(self) -> List[str]:
        """Return the unique entity texts of the ``"company"`` category."""
        return self.get_document_entities("company")

    def get_people_names(self) -> List[str]:
        """Return the unique entity texts of the ``"person"`` category."""
        return self.get_document_entities("person")

    def get_place_names(self) -> List[str]:
        """Return the unique entity texts of the ``"space"`` category."""
        return self.get_document_entities("space")

    def get_time_period_names(self) -> List[str]:
        """Return the unique entity texts of the ``"time"`` category."""
        return self.get_document_entities("time")

    def get_product_names(self) -> List[str]:
        """Return the unique entity texts of the ``"product"`` category."""
        return self.get_document_entities("product")

    def get_numbers(self) -> List[str]:
        """Return the unique entity texts of the ``"numeric"`` category."""
        return self.get_document_entities("numeric")
