from typing import List, Optional

from django.db import models
from structlog import get_logger

from apadata.models import Keyword
from apadata.pipelines.classify_keywords_pipeline import ClassifyKeywordsPipeline
from apadata.pipelines.pipeline_context import PipelineContext

# Module-level structlog logger; ``__name__`` is passed through to structlog's
# logger factory so log records can be attributed to this module.
logger = get_logger(__name__)


def classify_keywords_task(
    *,
    keyword: str,
    categories: Optional[List[str]] = None,
    recalculate: bool = False,
    prompt_index: int = 0,
) -> str:
    """Classify ``keyword`` into one of ``categories`` and persist the result.

    If a ``Keyword`` row whose ``name`` equals ``keyword`` already exists and
    ``recalculate`` is false, its stored ``category`` is returned as-is and no
    pipeline is run. Otherwise ``ClassifyKeywordsPipeline`` is run and its
    result is stored on the existing row (or on a newly created ``Keyword``).
    The result is then returned.

    Args:
        keyword: Keyword to classify; matched against ``Keyword.name``.
        categories: Candidate categories forwarded to the pipeline payload.
            ``None`` or an empty list falls back to the default set
            ``["topic", "tech", "function", "service", "industry"]``.
        recalculate: When true, re-run classification even if a stored
            category already exists.
        prompt_index: Added to the pipeline context under ``"prompt_index"``.

    Returns:
        ``str()`` of the stored category or of the freshly computed result.
    """
    if not categories:
        # An empty list gives the classifier nothing to choose from, so it is
        # treated the same as "not provided".
        categories = ["topic", "tech", "function", "service", "industry"]

    # Guard only the lookup, and catch only Keyword's own DoesNotExist. A broad
    # ObjectDoesNotExist handler around the cached-category read would turn any
    # DoesNotExist raised there into "keyword missing", and the code would then
    # silently build and save a second Keyword with the same name.
    try:
        keyword_object = Keyword.objects.get(name=keyword)
    except Keyword.DoesNotExist:
        keyword_object = Keyword(name=keyword)
    else:
        if not recalculate:
            # NOTE(review): if ``category`` can be null, this returns the
            # string "None" rather than triggering a classification — confirm.
            return str(keyword_object.category)

    payload = {"keyword": keyword, "categories": categories}
    context = PipelineContext(payload=payload)
    context.add("prompt_index", prompt_index)
    pipeline = ClassifyKeywordsPipeline(context=context)
    result = pipeline.run().result

    # NOTE(review): the lookup above and this save are not atomic. Two
    # concurrent calls for the same new keyword can both reach save().
    # Depending on the constraints on Keyword.name, that yields duplicate rows
    # or an IntegrityError — confirm whether this needs update_or_create or a lock.
    keyword_object.category = result
    keyword_object.save()

    return str(result)
