diff --git a/packages/ragbits-core/tests/unit/audit/test_trace.py b/packages/ragbits-core/tests/unit/audit/test_trace.py index 98a14e48..1cc1591a 100644 --- a/packages/ragbits-core/tests/unit/audit/test_trace.py +++ b/packages/ragbits-core/tests/unit/audit/test_trace.py @@ -1,8 +1,124 @@ +import asyncio from collections.abc import Callable +from unittest.mock import MagicMock import pytest -from ragbits.core.audit import _get_function_inputs +from ragbits.core.audit import _get_function_inputs, set_trace_handlers, trace, traceable +from ragbits.core.audit.base import TraceHandler + + +class MockTraceHandler(TraceHandler): + def start(self, name: str, inputs: dict) -> None: + pass + + def stop(self, outputs: dict) -> None: + pass + + def error(self, error: Exception) -> None: + pass + + +@pytest.fixture +def mock_handler() -> MockTraceHandler: + handler = MockTraceHandler() + set_trace_handlers(handler) + return handler + + +def test_trace_context_with_name(mock_handler: MockTraceHandler) -> None: + mock_handler.start = MagicMock() # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + with trace(name="test", input1="value1") as outputs: + outputs.result = "success" + + mock_handler.start.assert_called_once_with(name="test", inputs={"input1": "value1"}) + mock_handler.stop.assert_called_once_with({"result": "success"}) + mock_handler.error.assert_not_called() + + +def test_trace_context_without_name(mock_handler: MockTraceHandler) -> None: + mock_handler.start = MagicMock() # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + with trace() as outputs: + outputs.result = "success" + + mock_handler.start.assert_called_once_with(name="test_trace_context_without_name", inputs={}) + mock_handler.stop.assert_called_once_with({"result": "success"}) + mock_handler.error.assert_not_called() + + +def test_trace_context_exception(mock_handler: MockTraceHandler) -> None: + mock_handler.start = MagicMock() # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + with pytest.raises(ValueError), trace(name="test"): + raise ValueError("test error") + + mock_handler.start.assert_called_once_with(name="test", inputs={}) + mock_handler.error.assert_called_once() + mock_handler.stop.assert_not_called() + + +def test_traceable_sync(mock_handler: MockTraceHandler) -> None: + mock_handler.start = MagicMock() # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + @traceable + def sample_sync_function(a: int, b: str = "default") -> str: + return f"{a}-{b}" + + result = sample_sync_function(1, b="test") + assert result == "1-test" + + mock_handler.start.assert_called_once_with( + name="test_traceable_sync..sample_sync_function", + inputs={"a": 1, "b": "test"}, + ) + mock_handler.stop.assert_called_once_with({"returned": "1-test"}) + + +async def test_traceable_async(mock_handler: MockTraceHandler) -> None: + mock_handler.start = MagicMock() # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + @traceable + async def sample_async_function(x: int) -> int: + await asyncio.sleep(0.01) + return x * 2 + + result = await sample_async_function(5) + assert result == 10 + + mock_handler.start.assert_called_once_with( + name="test_traceable_async..sample_async_function", + inputs={"x": 5}, + ) + mock_handler.stop.assert_called_once_with({"returned": 10}) + + +def test_traceable_no_return(mock_handler: MockTraceHandler) -> None: + mock_handler.start = MagicMock() # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + @traceable + def void_function(x: int) -> None: + pass + + void_function(1) + mock_handler.start.assert_called_once_with( + name="test_traceable_no_return..void_function", + inputs={"x": 1}, + ) + mock_handler.stop.assert_called_once_with({}) @pytest.mark.parametrize(