from dataclasses import dataclass
from unittest.mock import patch

from apadata.models import Summary
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.tasks import extract_summary_task


@dataclass
class TestResult:
    """Minimal stand-in for the object returned by ``SummariesPipeline.run()``.

    The test below configures the mocked pipeline's ``run`` to return an
    instance of this class. It then expects the task to return the ``result``
    attribute.
    """

    # The class name starts with "Test", so pytest would try to collect it as
    # a test class. It would then warn because the dataclass defines __init__.
    # This unannotated class attribute opts out of collection and is not a
    # dataclass field.
    __test__ = False

    result: str


@patch("apadata.tasks.extract_summary_task.SummariesPipeline")
def test_extract_summary_task(mock_summaries_pipeline):
    """Exercise ``extract_summary_task`` against a mocked ``SummariesPipeline``.

    The task is called twice, first with ``recalculate=False`` and then with
    ``recalculate=True``. Each call must return the text produced by the
    mocked pipeline's ``run()``. Afterwards the ``Summary`` row for the
    company/keyword pair must hold that text. The first call recorded on the
    mocked pipeline class must have received a ``PipelineContext`` whose
    payload mirrors the task arguments.

    NOTE(review): this queries the ``Summary`` table and assumes DB access is
    enabled for this test, e.g. via a marker or an autouse fixture in
    conftest. Confirm.
    """
    company, keyword, domain = "accenture", "Audit", "accenture.com"
    expected_text = "skibidi bop"

    pipeline_instance = mock_summaries_pipeline.return_value
    pipeline_instance.run.return_value = TestResult(result=expected_text)

    for recalculate in (False, True):
        summary_text = extract_summary_task(
            company=company,
            keyword=keyword,
            domain=domain,
            recalculate=recalculate,
        )
        assert summary_text == expected_text

    stored_summary = Summary.objects.get(company=company, keyword=keyword)
    assert stored_summary.text == summary_text

    # Presumably the task instantiates the pipeline before touching anything
    # else on it, so the first recorded call is the constructor call.
    constructor_kwargs = mock_summaries_pipeline.mock_calls[0].kwargs
    assert "context" in constructor_kwargs
    pipeline_context = constructor_kwargs["context"]
    assert isinstance(pipeline_context, PipelineContext)
    assert pipeline_context.payload == {
        "company_name": company,
        "keyword": keyword,
        "domain": domain,
    }
