from unittest.mock import patch

import pytest

from apadata.modules.summary_module import (
    SummaryModule,
    query1,
    query1a,
    query1b,
    query1c,
    query2,
    query3,
    query4,
)
from apadata.pipelines.pipeline_context import PipelineContext


@patch("apadata.chatgpt.chat_gpt.openai")
def test_summary_module(mock_openai):
    """Check SummaryModule and the query helpers against a mocked openai.

    The ``openai`` name inside ``apadata.chatgpt.chat_gpt`` is patched so that
    ``ChatCompletion.create`` returns a canned response. The test then asserts:

    * running the module on a dict payload returns the canned message content;
    * every imported ``query*`` helper returns a ``str``;
    * running the module on a float payload raises ``ValueError``.
    """
    expected_reply = "this is a mock message"
    mock_openai.ChatCompletion.create.return_value = {
        "choices": [
            {"message": {"role": "mock-role", "content": expected_reply}}
        ]
    }

    company, keyword, text = "apadua", "gg", "we are apadua"

    # Happy path: a dict payload plus the two extra context entries.
    good_context = PipelineContext(
        payload={"company_name": company, "keyword": keyword, "text": text}
    )
    for key, value in (("temperature", 2.0), ("prompt_index", 0)):
        good_context.add(key, value)
    summary_module: SummaryModule = SummaryModule()
    assert summary_module.run(good_context) == expected_reply

    # Each query helper is called with the same arguments and must yield text.
    query_builders = (
        query1,
        query1a,
        query1b,
        query1c,
        query2,
        query3,
        query4,
    )
    for build_query in query_builders:
        assert isinstance(build_query(company, keyword, 100, text), str)

    # Failure path: a float payload is expected to be rejected by run().
    bad_context = PipelineContext(payload=3.14)
    failing_module: SummaryModule = SummaryModule()
    with pytest.raises(ValueError):
        failing_module.run(bad_context)
