from typing import List

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.spacy import SpacyConfiguration
from apadata.spacy.constants import (
    LangCode,
    SpacyGenre,
    SpacyPipeline,
    SpacyWordRootOption,
)
from apadata.text_processors import SpacyTextProcessor

from ..target_industries_mentions_module import TargetIndustriesMentionsModule


def test_target_industries_mentions_module():
    """Check that the module reports the expected industries for a sample text.

    The test builds an English spaCy configuration, extracts company names
    from the sample text, places them in a pipeline payload together with the
    text and language, runs ``TargetIndustriesMentionsModule`` on that context
    and compares the ``target_industries`` value stored in the context with
    the expected labels (order-insensitive).
    """
    target_industries: List[str] = ["Software & Internet", "High Tech"]
    pipeline_genre = SpacyGenre(SpacyGenre.CORE_WEB)
    pipeline_type = SpacyPipeline(SpacyPipeline.SM_PIPELINE)
    word_root_pipeline = SpacyWordRootOption(SpacyWordRootOption.LEMMATIZER)
    lang = LangCode("en")
    # Adjacent string literals are concatenated verbatim, so the first piece
    # must end with a space; otherwise the text reads "Googlein Bucharest"
    # and the company name in that sentence is fused with the next word.
    text = (
        "She studied computer science. She is a software engineer. She works for Google "
        "in Bucharest. Google is a high tech company."
    )
    spacy_config = SpacyConfiguration(
        lang_code=lang,
        pipeline_type=pipeline_type,
        pipeline_genre=pipeline_genre,
        word_root_pipeline=word_root_pipeline,
    )
    spacy_client = SpacyTextProcessor(text=text, spacy_config=spacy_config)
    # The document is obtained before company names are requested.
    spacy_client.obtain_doc()
    company_names = spacy_client.get_company_names()
    payload = {"company_names": company_names, "text": text, "language": "en"}
    context = PipelineContext(payload=payload)
    target_module: TargetIndustriesMentionsModule = TargetIndustriesMentionsModule()
    target_module.run(context)
    pred_target_industries = context.get("target_industries")
    # Fail with a readable message instead of a TypeError from sorted(None).
    assert (
        pred_target_industries is not None
    ), "module did not store 'target_industries' in the context"
    assert sorted(pred_target_industries) == sorted(target_industries)
