from typing import Any, List

import pytest

from apadata.pipelines.pipeline import Pipeline, PipelineError
from apadata.pipelines.pipeline_context import (
    ExecutionState,
    ExecutionStatus,
    PipelineContext,
)
from apadata.pipelines.pipeline_module import PipelineModule


class TestModule1(PipelineModule[str, str]):
    """Test double that appends ``"test1"`` to the payload it receives.

    Its behaviour is driven by flags stored in the pipeline context:
    a truthy ``skip1`` makes :meth:`should_skip` return True, and a truthy
    ``error1`` makes :meth:`run` raise ``RuntimeError``.
    """

    # Without this, the ``Test`` prefix makes pytest try to collect this helper
    # as a test class (and warn, because it defines ``__init__``).
    __test__ = False

    def __init__(self, name: str = "testmodule1", **kwargs: Any):
        # Extra kwargs are forwarded to PipelineModule; the tests use this to
        # pass per-module ``pre_process`` / ``post_process`` callables.
        super().__init__(name, **kwargs)

    def should_skip(self, context: PipelineContext) -> bool:
        """Return True when the context carries a truthy ``skip1`` flag."""
        return bool(context.get("skip1"))

    def run(self, context: PipelineContext) -> str:
        """Record marker values in the context and return ``payload + "test1"``.

        Raises:
            RuntimeError: if the context carries a truthy ``error1`` flag.
        """
        if context.get("error1"):
            raise RuntimeError("Error from module 1")

        context.add("module1", "mod1")
        # Store the payload this module actually received so tests can verify
        # what earlier stages (pre-processing, previous modules) handed over.
        context.add("cloned_input1", context.payload)
        return str(context.payload) + "test1"


class TestModule2(PipelineModule[str, str]):
    """Test double that appends ``"test2"`` to the payload it receives.

    A truthy ``skip2`` flag in the pipeline context makes :meth:`should_skip`
    return True.
    """

    # Without this, the ``Test`` prefix makes pytest try to collect this helper
    # as a test class (and warn, because it defines ``__init__``).
    __test__ = False

    def __init__(self, name: str = "testmodule2", **kwargs: Any):
        # Extra kwargs are forwarded unchanged to PipelineModule.
        super().__init__(name, **kwargs)

    def should_skip(self, context: PipelineContext) -> bool:
        """Return True when the context carries a truthy ``skip2`` flag."""
        return bool(context.get("skip2"))

    def run(self, context: PipelineContext) -> str:
        """Record marker values in the context and return ``payload + "test2"``."""
        context.add("module2", "mod2")
        # Store the payload this module actually received so tests can verify
        # the previous module's output was chained into it.
        context.add("cloned_input2", context.payload)
        return str(context.payload) + "test2"


@pytest.fixture
def pipeline():
    """Provide a fresh ``Pipeline`` whose context payload is ``"payload"``.

    Each test adds the modules it needs. No teardown is required, so the
    fixture simply returns the instance.
    """
    return Pipeline(PipelineContext("payload"))


def test_pipeline_base_class_one_module(pipeline):
    """A single module runs successfully and leaves the pipeline payload untouched."""
    pipeline.add_module(TestModule1())
    ctx = pipeline.run()

    # Module 1 ran and saw the original payload.
    assert ctx.get("module1") == "mod1"
    assert ctx.get("cloned_input1") == "payload"
    assert ctx.payload == "payload"

    # Module 2 was never added, so none of its markers exist.
    for absent_key in ("module2", "cloned_input2"):
        assert ctx.get(absent_key) is None

    states = ctx.states
    assert len(states) == 1
    assert states[0].status == ExecutionStatus.SUCCESS


def test_pipeline_base_class_two_modules(pipeline):
    """Two chained modules both succeed, each feeding its result to the next."""
    pipeline.add_module(TestModule1()).add_module(TestModule2())
    ctx = pipeline.run()

    assert ctx.get("module1") == "mod1"
    assert ctx.get("cloned_input1") == "payload"
    assert ctx.payload == "payload"
    assert ctx.get("module2") == "mod2"
    assert ctx.get("cloned_input2") == "payloadtest1"

    states = ctx.states
    assert len(states) == 2

    # (input payload, produced result) expected for each module, in order;
    # module 2's input is module 1's output.
    expected_io = [
        ("payload", "payloadtest1"),
        ("payloadtest1", "payloadtest1test2"),
    ]
    for index, (expected_in, expected_out) in enumerate(expected_io):
        state = states[index]
        assert state.status == ExecutionStatus.SUCCESS
        assert state.context.payload == expected_in
        assert state.context.result == expected_out


def test_pipeline_base_class_two_modules_skip_first(pipeline):
    """Skipping module 1 records a SKIPPED state and module 2 gets the original payload."""
    pipeline.add_module(TestModule1()).add_module(TestModule2())
    # TestModule1.should_skip reads this flag from the pipeline's context.
    pipeline.context.add("skip1", True)
    result_context = pipeline.run()
    assert result_context.get("module1") is None
    assert result_context.get("cloned_input1") is None
    assert result_context.payload == "payload"
    assert result_context.get("module2") == "mod2"
    assert result_context.get("cloned_input2") == "payload"

    # The skipped module still gets an execution state, just with no result.
    assert len(result_context.states) == 2
    assert result_context.states[0].status == ExecutionStatus.SKIPPED
    assert result_context.states[0].context.payload == "payload"
    assert result_context.states[0].context.result is None

    # Module 1 produced no output, so module 2 receives the original,
    # unmodified payload.
    assert result_context.states[1].status == ExecutionStatus.SUCCESS
    assert result_context.states[1].context.payload == "payload"
    assert result_context.states[1].context.result == "payloadtest2"


def test_pipeline_error_on_module_one(pipeline):
    """A failing module aborts the run with PipelineError carrying a FAILURE state."""
    # Use the fixture's pipeline instead of shadowing it with a locally built
    # one. TestModule1.run raises RuntimeError when this flag is set.
    pipeline.context.add("error1", True)
    pipeline.add_module(TestModule1()).add_module(TestModule2())

    with pytest.raises(PipelineError) as error:
        pipeline.run()

    failed_context = error.value.context
    state_opt = failed_context.get_last_module_execution()
    assert state_opt is not None
    state: ExecutionState = state_opt
    assert state.status == ExecutionStatus.FAILURE
    assert str(state.error) == "Error from module 1"
    assert state.context.payload == "payload"
    assert state.context.result is None
    # Module 2 never ran, so only module 1's state was recorded.
    assert len(failed_context.states) == 1


def test_pipeline_setting_context_and_modules_not_allowed(pipeline):
    """Assigning to the pipeline's or its context's guarded attributes raises SyntaxError."""
    for pipeline_attr in ("context", "modules"):
        with pytest.raises(SyntaxError):
            setattr(pipeline, pipeline_attr, None)
    for context_attr in ("payload", "result"):
        # pipeline.context is re-read on every iteration, matching a fresh
        # attribute access per assignment.
        with pytest.raises(SyntaxError):
            setattr(pipeline.context, context_attr, None)


def test_pipeline_adds_module_to_position(pipeline):
    """``add_module(..., position=0)`` places the new module before existing ones."""
    pipeline.add_module(TestModule1())
    pipeline.add_module(TestModule2(), position=0)

    # Check against the classes themselves (identity-based isinstance) rather
    # than comparing class-name strings.
    assert isinstance(pipeline.modules[0], TestModule2)
    assert isinstance(pipeline.modules[1], TestModule1)


def test_pipeline_pre_process():
    """A pipeline-level ``pre_process`` override rewrites the input of the first module."""

    class PreProcessPipeline(Pipeline):
        def pre_process(self, context: PipelineContext) -> str:
            # Leave a marker so the test can confirm the hook was invoked.
            context.add("pre_process", True)
            prefix = "pre_processed_"
            return prefix + str(context.payload)

    pre_pipeline = PreProcessPipeline(PipelineContext("payload"))
    pre_pipeline.add_module(TestModule1())
    outcome = pre_pipeline.run()

    assert outcome.get("pre_process") is True
    assert outcome.get("module1") == "mod1"
    # Module 1 saw the pre-processed value, not the raw payload...
    assert outcome.get("cloned_input1") == "pre_processed_payload"
    assert outcome.result == "pre_processed_payloadtest1"
    # ...while the context's own payload is left as originally given.
    assert outcome.payload == "payload"


def test_pipeline_post_process():
    """A pipeline-level ``post_process`` override replaces the final result."""

    class PostProcessPipeline(Pipeline):
        def post_process(self, context: PipelineContext) -> List[str]:
            # Leave a marker so the test can confirm the hook was invoked.
            context.add("post_process", True)
            module_result = context.result
            return [module_result, "post_process_result"]

    post_pipeline = PostProcessPipeline(PipelineContext("payload"))
    post_pipeline.add_module(TestModule1())
    outcome = post_pipeline.run()

    assert outcome.get("post_process") is True
    assert outcome.get("module1") == "mod1"
    assert outcome.get("cloned_input1") == "payload"
    # The hook wrapped module 1's output together with its own extra value.
    assert outcome.result == ["payloadtest1", "post_process_result"]
    assert outcome.payload == "payload"


def test_pipeline_base_class_one_module_pre_process(pipeline):
    """A module-level ``pre_process`` callable transforms that module's input."""
    module = TestModule1(
        pre_process=lambda context: f"processed_{context.payload}",
    )
    pipeline.add_module(module)
    ctx = pipeline.run()

    assert ctx.get("module1") == "mod1"
    # The module received the transformed value; the context payload is unchanged.
    assert ctx.get("cloned_input1") == "processed_payload"
    assert ctx.payload == "payload"
    for absent_key in ("module2", "cloned_input2"):
        assert ctx.get(absent_key) is None
    assert ctx.result == "processed_payloadtest1"

    states = ctx.states
    assert len(states) == 1
    assert states[0].status == ExecutionStatus.SUCCESS


def test_pipeline_base_class_one_module_post_process(pipeline):
    """Module-level ``pre_process`` and ``post_process`` callables both apply."""
    module = TestModule1(
        pre_process=lambda context: f"processed_{context.payload}",
        post_process=lambda context: f"{context.result}_post_processed",
    )
    pipeline.add_module(module)
    ctx = pipeline.run()

    assert ctx.get("module1") == "mod1"
    assert ctx.get("cloned_input1") == "processed_payload"
    assert ctx.payload == "payload"
    for absent_key in ("module2", "cloned_input2"):
        assert ctx.get(absent_key) is None
    # pre_process shaped the input; post_process then decorated the module's output.
    assert ctx.result == "processed_payloadtest1_post_processed"

    states = ctx.states
    assert len(states) == 1
    assert states[0].status == ExecutionStatus.SUCCESS
