from typing import Optional

import json
from dataclasses import dataclass
from unittest.mock import patch

from apadata.loaders.text_loader import TextLoader
from apadata.pipelines.pipeline_context import PipelineContext

from ..sitemap_api_module import SitemapAPIModule


@dataclass
class TestResponse:
    """Minimal stand-in for the response object returned by ``SitemapAPI.get``.

    Only ``content`` is provided; presumably the code under test reads nothing
    else from the response (NOTE(review): confirm against ``SitemapAPI``).
    """

    # The ``Test`` prefix makes pytest try to collect this helper as a test
    # class; since ``@dataclass`` defines ``__init__`` that only produces a
    # PytestCollectionWarning. ``__test__ = False`` opts it out of collection.
    # (No annotation, so this is not a dataclass field.)
    __test__ = False

    # Raw response body bytes.
    content: bytes


def mocked_get_response(endpoint: str) -> Optional[TestResponse]:
    """Return a canned response for ``endpoint``, used as the mocked ``SitemapAPI.get``.

    ``"sitemap.xml"`` and ``"sitemap_index.xml"`` are served from fixture text
    files under ``api/sitemap_api/tests/``. The file text is encoded as UTF-8
    into ``TestResponse.content``. Any other endpoint yields ``None``.

    NOTE(review): the fixture paths are relative. If ``TextLoader`` resolves
    them against the working directory, the tests must run from the directory
    containing ``api/`` (TODO confirm).
    """
    fixture_dir = "api/sitemap_api/tests/"
    # (endpoint, fixture file name) pairs, checked in order.
    known_fixtures = (
        ("sitemap.xml", "mocked_sitemap_xml_response.txt"),
        ("sitemap_index.xml", "mocked_sitemap_index_xml_response.txt"),
    )
    for known_endpoint, fixture_name in known_fixtures:
        if endpoint == known_endpoint:
            raw_text = TextLoader(filepath=fixture_dir + fixture_name).load()
            return TestResponse(content=bytes(raw_text, "utf-8"))
    return None


@patch("apadata.api.sitemap_api.sitemap_api.SitemapAPI.get")
def test_sitemap_api_module(mock_get, snapshot):
    """Run ``SitemapAPIModule`` on mocked sitemap responses and snapshot its output.

    ``SitemapAPI.get`` is patched so each call is answered by
    ``mocked_get_response``, which returns responses loaded from local fixture
    files. After ``run``, the values under ``sitemap`` and
    ``categories_to_urls`` in the pipeline context are JSON-serialised and
    compared against stored snapshots.
    """
    mock_get.side_effect = mocked_get_response

    pipeline_ctx = PipelineContext(payload={"domain": "https://kpmg.us"})
    SitemapAPIModule().run(pipeline_ctx)

    # (context key, snapshot name) pairs, asserted in this order.
    snapshot_checks = (
        ("sitemap", "sitemap-api-sitemap"),
        ("categories_to_urls", "sitemap-api-category-to-urls-success"),
    )
    for context_key, snapshot_name in snapshot_checks:
        snapshot.assert_match(json.dumps(pipeline_ctx.get(context_key)), snapshot_name)
