from dataclasses import dataclass
from unittest.mock import patch

from apadata.models import Keyword
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.tasks import classify_keywords_task


@dataclass
class TestResult:
    """Stand-in for the value returned by the mocked pipeline's ``run()``.

    Only the ``result`` attribute is populated; it carries the category string
    the test expects ``classify_keywords_task`` to return.
    """

    # The name starts with "Test", so pytest would try to collect this class
    # and warn because the dataclass defines ``__init__``. Opt out explicitly.
    # Unannotated, so dataclass does not turn it into a field.
    __test__ = False

    result: str


@patch("apadata.tasks.classify_keywords_task.ClassifyKeywordsPipeline")
def test_task(mock_classify_keywords_pipeline):
    """Run ``classify_keywords_task`` with ``ClassifyKeywordsPipeline`` mocked out.

    For both ``recalculate=True`` and ``recalculate=False`` this checks that the
    task returns ``"service"`` and that the ``Keyword`` row for the test keyword
    stores that same category. It then checks that the pipeline was constructed
    with a ``context`` keyword argument that is a ``PipelineContext`` whose
    payload holds the keyword and the categories.
    """
    # ``ClassifyKeywordsPipeline(...).run()`` is stubbed to return an object
    # whose ``result`` is "service"; the task is expected to surface that value.
    mock_classify_keywords_pipeline.return_value.run.return_value = TestResult(
        result="service"
    )
    test_keyword = "Consulting"
    test_categories = ["topic", "tech", "function", "service", "industry"]
    # Snapshot the categories before the task runs, so an in-place mutation of
    # the list by the task would still be caught by the payload check below.
    expected_payload = {
        "keyword": test_keyword,
        "categories": list(test_categories),
    }

    for recalculate in (True, False):
        result = classify_keywords_task(
            keyword=test_keyword,
            categories=test_categories,
            recalculate=recalculate,
        )
        assert result == "service"
        # Checked on every iteration, not just the last, so the
        # recalculate=True path is also verified to persist the category.
        assert Keyword.objects.get(name=test_keyword).category == result

    # ``call_args_list`` records only direct calls to the mocked class (i.e.
    # pipeline construction), unlike ``mock_calls`` which also records child
    # calls such as ``().run()``.
    first_call_kwargs = mock_classify_keywords_pipeline.call_args_list[0].kwargs
    assert "context" in first_call_kwargs
    context = first_call_kwargs["context"]
    assert isinstance(context, PipelineContext)
    assert context.payload == expected_payload
