from typing import Any

from apadata.pipelines.pipeline_context import PipelineContext
from apadata.pipelines.pipeline_module import PipelineModule


class TestModuleNoSkip(PipelineModule[str, str]):
    """Minimal pipeline module that keeps the inherited ``should_skip``.

    Only ``run`` is overridden (it returns an empty string), so tests can
    observe the skip behaviour provided by ``PipelineModule`` itself.
    """

    # The ``Test`` name prefix makes pytest try to collect this helper as a
    # test class; because it defines ``__init__`` that fails with a
    # PytestCollectionWarning. Opt out of collection explicitly.
    __test__ = False

    def __init__(self, name: str = "testmodulenoskip", **kwargs: Any):
        """Create the module.

        Args:
            name: Module name passed to ``PipelineModule``; defaults to
                ``"testmodulenoskip"``.
            **kwargs: Forwarded unchanged to ``PipelineModule.__init__``.
        """
        super().__init__(name, **kwargs)

    def run(self, context: PipelineContext) -> str:
        """Return an empty string without reading or modifying *context*."""
        return ""


class TestModule(PipelineModule[str, str]):
    """Pipeline module used to exercise the ``PipelineModule`` hooks.

    ``should_skip`` is overridden to honour any truthy ``"skip"`` context
    entry. ``run`` records two entries in the context and returns the payload
    with ``"test"`` appended.
    """

    # The ``Test`` name prefix makes pytest try to collect this helper as a
    # test class; because it defines ``__init__`` that fails with a
    # PytestCollectionWarning. Opt out of collection explicitly.
    __test__ = False

    def __init__(self, name: str = "testmodule", **kwargs: Any):
        """Create the module.

        Args:
            name: Module name passed to ``PipelineModule``; defaults to
                ``"testmodule"``.
            **kwargs: Forwarded unchanged to ``PipelineModule.__init__``
                (the tests pass ``pre_process`` / ``post_process`` hooks).
        """
        super().__init__(name, **kwargs)

    def should_skip(self, context: PipelineContext) -> bool:
        """Return True when *context* holds a truthy ``"skip"`` entry."""
        return bool(context.get("skip"))

    def run(self, context: PipelineContext) -> str:
        """Record marker entries in *context* and return the payload + ``"test"``.

        Side effects: adds ``"test" -> "test"`` and ``"cloned_input" -> payload``
        to *context*. ``"cloned_input"`` is given the payload object as-is; no
        copy is made here.
        """
        context.add("test", "test")
        context.add("cloned_input", context.payload)
        return str(context.payload) + "test"


def test_module_base_class():
    """A non-skipped ``TestModule.run`` records context entries and extends the payload."""
    ctx = PipelineContext("payload")
    mod = TestModule()

    assert not mod.should_skip(ctx)

    output = mod.run(ctx)

    assert ctx.get("test") == "test"
    assert ctx.get("cloned_input") == "payload"
    # run() returns a derived value; the context payload itself stays untouched.
    assert ctx.payload == "payload"
    assert output == "payloadtest"


def test_module_base_class_should_skip_defaults_to_false():
    """Without any ``"skip"`` entry the inherited ``should_skip`` returns False."""
    plain_module = TestModuleNoSkip()
    empty_context = PipelineContext("payload")
    assert not plain_module.should_skip(empty_context)


def test_module_base_class_should_skip():
    """``should_skip`` is honoured both through an override and through the base class."""
    # Overridden should_skip: a truthy "skip" entry makes TestModule skip.
    flagged_context = PipelineContext("payload")
    flagged_context.add("skip", True)
    overriding_module = TestModule()
    assert overriding_module.should_skip(flagged_context)

    # Inherited should_skip: a "skip" list containing the module's default
    # name ("testmodulenoskip") makes TestModuleNoSkip skip.
    named_context = PipelineContext("payload")
    named_context.add("skip", ["testmodulenoskip"])
    inheriting_module = TestModuleNoSkip()
    assert inheriting_module.should_skip(named_context)


def test_module_base_class_pre_process():
    """A ``pre_process`` hook given to the constructor rewrites the payload seen by ``run``."""
    ctx = PipelineContext("payload")
    mod = TestModule(pre_process=lambda c: f"processed_{c.payload}")

    # Write the hook's output straight into the (private) payload slot so that
    # run() operates on the processed value.
    ctx._payload = mod.pre_process(ctx)

    assert not mod.should_skip(ctx)
    output = mod.run(ctx)

    assert ctx.get("test") == "test"
    assert ctx.get("cloned_input") == "processed_payload"
    assert ctx.payload == "processed_payload"
    assert output == "processed_payloadtest"


def test_module_base_class_post_process():
    """A ``post_process`` hook given to the constructor transforms the result of ``run``."""
    ctx = PipelineContext("payload")
    mod = TestModule(
        post_process=lambda c: f"post_processed_{c.result}",
    )

    assert not mod.should_skip(ctx)
    # Store run()'s output in the (private) result slot so the hook can read
    # it back through ``.result``.
    ctx._result = mod.run(ctx)
    final = mod.post_process(ctx)

    assert ctx.get("test") == "test"
    assert ctx.get("cloned_input") == "payload"
    assert ctx.payload == "payload"
    assert final == "post_processed_payloadtest"
