from typing import Any, Dict, List, Optional

from apadata.chatgpt.chat_gpt import ChatGpt
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule

# Result type produced by the module: the text returned by the ChatGPT call
# in KeywordCategoryModule.run.
Output = str
# Payload type consumed by the module: a string-keyed mapping of fields.
Input = Dict[str, Any]


def query0(keyword: str, categories: Optional[List[str]]) -> str:
    """Build the prompt asking ChatGPT to pick one category for ``keyword``.

    @param keyword: str - the keyword to categorize
    @param categories: Optional[List[str]] - candidate categories; ``None``
        is treated as an empty list so the prompt never contains "None"
    @return: str - the prompt text
    """
    if categories is None:
        categories = []
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must end with a space to keep words from running together.
    return (
        f"Out of this list of categories {categories}, if you were to associate "
        f"only one category to the keyword {keyword}, which one would it be? "
        "Just give the category as the answer, nothing more."
    )


class KeywordCategoryModule(PipelineModule[Input, Output]):
    """Pipeline module that asks ChatGPT to pick one category for a keyword.

    The keyword and the candidate categories are read from the pipeline
    context, turned into a single user message and sent through ``ChatGpt``.
    """

    def __init__(
        self,
        name: str = "keyword_category",
        **kwargs: Any,
    ):
        """Create the module.

        @param name: str - module name handed to ``PipelineModule``
        @param kwargs: Any - forwarded unchanged to ``PipelineModule.__init__``
        """
        super().__init__(name, **kwargs)

    @staticmethod
    def make_query(
        keyword: str = "",
        categories: Optional[List[str]] = None,
    ) -> str:
        """Build the categorization prompt for ``keyword``.

        @param keyword: str - the keyword to categorize
        @param categories: Optional[List[str]] - candidate categories
        @return: str - the prompt text produced by ``query0``
        """
        # TODO: if more prompt variants are added, select one with something
        # like ``query_options[prompt_index % len(query_options)]`` instead of
        # always using ``query0``.
        return query0(keyword, categories)

    def run(self, context: PipelineContext) -> Output:
        """Ask ChatGPT which of the given categories best fits the keyword.

        @param context: PipelineContext - its payload must be a dict; the
            ``keyword`` and ``categories`` fields are read with
            ``context.search_field``, defaulting to ``""`` and ``[]``
        @return: Output (str) - the response returned by
            ``ChatGpt().execute_call``; the prompt asks for just the category
        @raise ValueError: if ``context.payload`` is not a dict
        """

        if not isinstance(context.payload, dict):
            raise ValueError("Payload should have been a dictionary")

        query = KeywordCategoryModule.make_query(
            keyword=str(context.search_field("keyword", "")),
            categories=context.search_field("categories", []),
        )
        # A single user-role chat message carrying the whole prompt.
        input_messages = [{"role": "user", "content": query}]

        return ChatGpt().execute_call(context, input_messages)
