from typing import Any, Optional

import openai
import torch
from sentence_transformers import SentenceTransformer

from apadata.constants import (
    OPEN_AI_VERSION,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
)

from ..spacy.constants import TextEmbeddingMethods
from ..utils import flatten
from ..vectordb.vectordb_api import VectorDBAPI
from .text_processor import TextProcessor


def set_openai_api():
    """Point the ``openai`` module at the project's configured endpoint.

    Copies the API base URL, key, type and version from
    ``apadata.constants`` onto attributes of the ``openai`` module. These are
    module-level attributes, i.e. process-wide state shared by every user of
    ``openai`` in this interpreter.
    """
    # Insertion order matches the original assignment order.
    client_settings = {
        "api_base": OPENAI_API_BASE,
        "api_key": OPENAI_API_KEY,
        "api_type": OPENAI_API_TYPE,
        "api_version": OPEN_AI_VERSION,
    }
    for attribute_name, attribute_value in client_settings.items():
        setattr(openai, attribute_name, attribute_value)


def load_model():
    """Load the ``sentence-transformers/all-MiniLM-L6-v2`` model.

    When ``torch`` reports CUDA as available, the model is moved to the GPU
    with ``.cuda()``. Otherwise it is returned exactly as constructed.
    """
    sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    gpu_available = torch.cuda.is_available()
    if gpu_available:
        sentence_model.cuda()
    return sentence_model


class TextEmbedderProcessor(TextProcessor):
    """
    Embed a text, insert the (text, embedding) pair into the vector DB, and
    return the embedding.

    The embedding back-end is selected by ``method``:

    * ``TextEmbeddingMethods.ST`` - the local sentence-transformers model
      returned by ``load_model``; loaded lazily on first use and shared by all
      instances.
    * ``TextEmbeddingMethods.ADA`` - OpenAI's ``text-embedding-ada-002`` via
      ``openai.Embedding.create``.
    """

    # Lazily created singletons, stored on this class so every instance (and
    # every subclass) shares them.
    VECTOR_DB_INSTANCE = None
    MODEL: Optional[SentenceTransformer] = None

    def __init__(self, text: str, method: str = TextEmbeddingMethods.ST):
        """
        :param text: the text to embed.
        :param method: the embedding back-end ``process`` should use; expected
            to be one of ``TextEmbeddingMethods``. Unsupported values are only
            rejected when ``process`` is called.
        """
        # the default method will be replaced with Ada in a following PR, but sentence
        # transformers was kept for now in order to ensure the code does not break in
        # multiple parts of the codebase
        super().__init__(text)
        self.method = method
        if self.method == TextEmbeddingMethods.ADA:
            # Configure the openai module globals before any API call is made.
            set_openai_api()

    def process(self) -> Any:
        """
        Compute the embedding of ``self.text``, insert it into the shared vector
        DB, and return it.

        :return: the embedding. For ``ST`` it is a list built from the model
            output cast to float. For ``ADA`` it is the ``embedding`` field of
            the first entry in the OpenAI response.
        :raises ValueError: if ``self.method`` is neither ``ST`` nor ``ADA``
            (raised before the vector DB is touched).
        """
        if self.method == TextEmbeddingMethods.ST:
            embedding = self._embed_with_sentence_transformer()
        elif self.method == TextEmbeddingMethods.ADA:
            embedding = self._embed_with_ada()
        else:
            raise ValueError(
                f"Wrong text embedding method given, it must be one of "
                f"the following: {TextEmbeddingMethods.values()}"
            )
        TextEmbedderProcessor.get_vector_db().insert(
            text=self.text, embedding=embedding
        )
        return embedding

    def _embed_with_sentence_transformer(self) -> list:
        """Embed ``self.text`` with the shared sentence-transformers model."""
        # Explicit ``is None`` rather than truthiness: model objects (e.g. torch
        # modules) may define ``__len__``, so truthiness would not reliably
        # mean "already loaded".
        if TextEmbedderProcessor.MODEL is None:
            TextEmbedderProcessor.MODEL = load_model()
        encoded = TextEmbedderProcessor.MODEL.encode(
            self.text, convert_to_tensor=False
        )
        # Cast to float, then flatten into a plain list.
        # NOTE(review): assumes ``flatten`` yields the scalar components of the
        # encoded array; confirm against ``..utils.flatten``.
        return list(flatten(encoded.astype("float")))

    def _embed_with_ada(self) -> Any:
        """Embed ``self.text`` with OpenAI's ``text-embedding-ada-002`` model."""
        response = openai.Embedding.create(
            input=self.text, model="text-embedding-ada-002"
        )
        # Only one input string is sent, so only the first result is read.
        return response["data"][0]["embedding"]

    @classmethod
    def get_vector_db(cls):
        """
        Return the shared ``VectorDBAPI`` instance, creating it on first use.

        The instance is stored on ``TextEmbedderProcessor`` itself (not on
        ``cls``), so subclasses share one instance instead of each creating
        their own.
        """
        # Explicit ``is None`` rather than truthiness. A client object that
        # defines ``__len__``/``__bool__`` could be falsy, e.g. when its
        # collection is empty, and would then be re-created on every call.
        if TextEmbedderProcessor.VECTOR_DB_INSTANCE is None:
            TextEmbedderProcessor.VECTOR_DB_INSTANCE = VectorDBAPI()
        return TextEmbedderProcessor.VECTOR_DB_INSTANCE
