import pytest

from apadata.spacy.constants import (
    LangCode,
    SpacyGenre,
    SpacyPipeline,
    SpacyWordRootOption,
)
from apadata.spacy.spacy_configuration import SpacyConfiguration


def test_spacy_configuration():
    """Check the pipeline name that ``SpacyConfiguration`` exposes.

    Three scenarios are exercised:

    * English / core-web / small pipeline, once with no optional arguments
      and once per optional keyword; every variant must report
      ``"en_core_web_sm"`` as its pipeline name.
    * German with the core-web genre, the small pipeline and
      ``use_lookup_lemmatizer=True`` must raise ``ValueError``.
    * German / core-news / small pipeline must report ``"de_core_news_sm"``.
    """
    en_lang = LangCode("en")
    web_genre = SpacyGenre(SpacyGenre.CORE_WEB)
    small_pipeline = SpacyPipeline(SpacyPipeline.SM_PIPELINE)

    # Each entry is one set of optional keyword arguments; none of them is
    # expected to change the resulting pipeline name.
    option_sets = (
        {},
        {"use_lookup_lemmatizer": True},
        {"use_senter_over_parser": True},
        {"use_default_over_trainable": True},
        {"word_root_pipeline": SpacyWordRootOption(SpacyWordRootOption.LEMMATIZER)},
    )
    for options in option_sets:
        spacy_config = SpacyConfiguration(
            lang_code=en_lang,
            pipeline_genre=web_genre,
            pipeline_type=small_pipeline,
            **options,
        )
        # The option set is the assertion message so a failure names the variant.
        assert spacy_config.pipeline_name == "en_core_web_sm", options

    # Build the language code outside the ``raises`` block so that only the
    # SpacyConfiguration call itself can satisfy the expected ValueError.
    de_lang = LangCode("de")
    with pytest.raises(ValueError):
        # NOTE(review): the rejection could come from the genre or from the
        # lookup lemmatizer flag; not visible here - confirm which is intended.
        SpacyConfiguration(
            lang_code=de_lang,
            pipeline_genre=web_genre,
            pipeline_type=small_pipeline,
            use_lookup_lemmatizer=True,
        )

    news_genre = SpacyGenre(SpacyGenre.CORE_NEWS)
    spacy_config = SpacyConfiguration(
        lang_code=de_lang, pipeline_genre=news_genre, pipeline_type=small_pipeline
    )
    assert spacy_config.pipeline_name == "de_core_news_sm"
