from typing import Any, Dict, List

# Note: The openai-python library support for Azure OpenAI is in preview.
import openai
from openai.error import ServiceUnavailableError

from apadata.constants import (
    OPEN_AI_VERSION,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
)
from apadata.pipelines.pipeline_context import PipelineContext
from apadata.utils import Singleton


def set_openai_api():
    """Apply the project's OpenAI connection constants to the ``openai`` module.

    Sets ``api_base``, ``api_key``, ``api_type`` and ``api_version`` as
    module-level attributes of ``openai``, so the change is global for the
    running process.
    """
    module_settings = {
        "api_base": OPENAI_API_BASE,
        "api_key": OPENAI_API_KEY,
        "api_type": OPENAI_API_TYPE,
        "api_version": OPEN_AI_VERSION,
    }
    for attribute_name, configured_value in module_settings.items():
        setattr(openai, attribute_name, configured_value)


class ChatGpt(metaclass=Singleton):
    """Small client around ``openai.ChatCompletion`` for chat requests.

    Creating an instance configures the module-level ``openai`` settings by
    calling ``set_openai_api``.

    NOTE(review): ``Singleton`` is imported from ``apadata.utils``; presumably
    it makes every ``ChatGpt()`` call return the same instance -- confirm there.
    """

    def __init__(self):
        set_openai_api()

    @staticmethod
    def get_output_from_response(response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Extract the message of every choice in a chat completion response.

        Args:
            response: Mapping with a ``"choices"`` list whose items each hold a
                ``"message"`` mapping with ``"content"`` and ``"role"`` keys.

        Returns:
            One ``{"content": ..., "role": ...}`` dict per choice, in the order
            the choices appear in ``response``.

        Raises:
            KeyError: If any of the expected keys is missing.
        """
        return [
            {
                "content": choice["message"]["content"],
                "role": choice["message"]["role"],
            }
            for choice in response["choices"]
        ]

    @staticmethod
    def default_configuration() -> Dict[str, Any]:
        """Return a fresh dict of the default chat completion parameters.

        A new dict is built on every call, so callers may modify the result
        without affecting later calls.
        """
        return {
            "engine": "gpt-35-turbo",
            "temperature": 1.0,
            "max_tokens": 200,
            "top_p": 0.95,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }

    def request(self, messages: List[Dict[Any, Any]], **kwargs: Any) -> Any:
        """Send ``messages`` to ``openai.ChatCompletion.create``.

        Args:
            messages: Chat messages forwarded unchanged as ``messages=``.
            **kwargs: Request parameters; they override the entries of
                ``default_configuration()`` that share the same key.

        Returns:
            Whatever ``openai.ChatCompletion.create`` returns, or ``None`` when
            the call raises ``ServiceUnavailableError``.
        """
        call_options = {**self.default_configuration(), **kwargs}
        try:
            return openai.ChatCompletion.create(messages=messages, **call_options)
        except ServiceUnavailableError:
            # Best-effort contract kept for direct callers: an unavailable
            # service is reported as ``None`` instead of an exception.
            return None

    def execute_call(
        self, context: PipelineContext, input_messages: List[Dict[str, str]]
    ) -> str:
        """Run one chat completion and return the first choice's content.

        Every default parameter is looked up through ``context.search_field``
        with the default value as fallback, so the context can override it.

        Args:
            context: Pipeline context queried for parameter overrides.
            input_messages: Chat messages to send.

        Returns:
            ``str()`` of the ``content`` of the first returned choice.

        Raises:
            RuntimeError: If the service was unavailable (``request`` returned
                ``None``) or the response contained no choices.
        """
        parameter_dict = {
            parameter: context.search_field(field=parameter, default=default)
            for parameter, default in self.default_configuration().items()
        }

        response = self.request(messages=input_messages, **parameter_dict)
        if response is None:
            # Without this guard the ``None`` would be indexed below and fail
            # with an unhelpful "'NoneType' object is not subscriptable".
            raise RuntimeError(
                "OpenAI service unavailable: chat completion returned no response"
            )

        output_messages = self.get_output_from_response(response)
        if not output_messages:
            raise RuntimeError("OpenAI chat completion response contained no choices")
        return str(output_messages[0]["content"])
