from typing import Optional

from dataclasses import dataclass
from unittest.mock import patch

from apadata.loaders.text_loader import TextLoader
from apadata.pipelines.target_industries_extraction_pipeline import (
    TargetIndustriesExtractionPipeline,
)

from ...text_processors.target_industries_utils import URL_PRESELECTION_TAGS_EN
from ..pipeline_context import PipelineContext

# Canned web-content lines returned by ``MockESClient.get_web_contents`` in
# place of real Elasticsearch results.
dummy_text = [
    "KPMG web content with client Accenture line 1",
    "KPMG web content offers the service software development line 2",
    "KPMG web content ensures quality in consulting line 3",
]


@dataclass
class TestResponse:
    """Minimal stand-in for an HTTP response that exposes only ``content``."""

    # The class name matches pytest's default ``Test*`` collection pattern and
    # the dataclass generates an ``__init__``, which makes pytest emit a
    # PytestCollectionWarning. Opting out keeps this helper from being treated
    # as a test class. The attribute is un-annotated, so it is a plain class
    # attribute and not a dataclass field.
    __test__ = False

    # Raw response body as bytes.
    content: bytes


def mocked_get_response(endpoint: str) -> Optional[TestResponse]:
    """Return a canned sitemap response for ``endpoint``.

    Intended as a mock ``side_effect`` replacing the real sitemap fetch.

    Args:
        endpoint: Endpoint string the caller asks for.

    Returns:
        A ``TestResponse`` whose ``content`` is the UTF-8 encoded text of the
        matching fixture file, or ``None`` when the endpoint is not recognised.
    """
    if endpoint == "sitemap.xml":
        fixture_path = "api/sitemap_api/tests/mocked_sitemap_xml_response.txt"
    elif endpoint == "sitemap_index.xml" or endpoint in URL_PRESELECTION_TAGS_EN:
        # The sitemap index and every preselection tag share the same fixture.
        fixture_path = "api/sitemap_api/tests/mocked_sitemap_index_xml_response.txt"
    else:
        return None

    fixture_text = TextLoader(filepath=fixture_path).load()
    return TestResponse(content=bytes(fixture_text, "utf-8"))


class MockESClient:
    """Elasticsearch client double exposing only static methods.

    Because nothing here needs instance state, the class object itself can be
    handed out wherever a client instance is expected.
    """

    @staticmethod
    def search(domain, keyword, size, url):
        """Ignore all arguments and report no hits."""
        no_hits = []
        return no_hits

    @staticmethod
    def get_web_contents(hits):
        """Ignore ``hits`` and return the module-level ``dummy_text`` list."""
        return dummy_text


@patch("apadata.modules.elasticsearch_client_module.ElasticsearchClient")
@patch("apadata.api.sitemap_api.sitemap_api.SitemapAPI.get")
@patch("apadata.api.clearbit.clearbit_api.ClearbitAPI.suggest")
@patch("apadata.api.clearbit.clearbit_api.ClearbitAPI.industry")
def test_target_industries_extraction_pipeline(
    mock_industry, mock_suggest, mock_sitemap_get, mock_elastic
):
    """Run the target-industries pipeline twice against mocked collaborators.

    The first run uses the context as created. The second run reuses the same
    context after adding a ``"skip"`` entry for ``"spacy"`` and ``"clearbit"``.
    Each run's ``target_industries`` result is compared order-insensitively
    with the expected list.
    """
    # Patches are applied bottom-up, so the mock arguments arrive in the
    # reverse of the decorator order.
    mock_elastic.return_value = MockESClient
    mock_suggest.return_value = ["https://kpmg.us"]
    mock_industry.return_value = ["Software & Internet", "Services"]
    mock_sitemap_get.side_effect = mocked_get_response

    pipeline_context = PipelineContext(payload={"domain": "https://kpmg.us"})

    # Run 1: full pipeline.
    expected_full = ["Software & Internet", "Services"]
    full_pipeline = TargetIndustriesExtractionPipeline(context=pipeline_context)
    full_outcome = full_pipeline.run()
    full_prediction = full_outcome.result["target_industries"]

    assert sorted(full_prediction) == sorted(expected_full)

    # Run 2: the original comment calls this the "only mentions" pipeline;
    # the spacy and clearbit entries are listed under "skip".
    expected_mentions = ["Software & Internet"]
    pipeline_context.add("skip", ["spacy", "clearbit"])
    mentions_pipeline = TargetIndustriesExtractionPipeline(context=pipeline_context)
    mentions_outcome = mentions_pipeline.run()
    mentions_prediction = mentions_outcome.result["target_industries"]

    assert sorted(mentions_prediction) == sorted(expected_mentions)
