Skip to content

Commit

Permalink
fix tests + unify formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Jan 15, 2025
1 parent c804f08 commit f4d9385
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 119 deletions.
33 changes: 33 additions & 0 deletions packages/ragbits-core/src/ragbits/core/audit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,36 @@ def trace(self, name: str, **inputs: Any) -> Iterator[SimpleNamespace]: # noqa:

span = self._spans.get().pop()
self.stop(outputs=vars(outputs), current_span=span)


def format_attributes(data: dict, prefix: str | None = None) -> dict:
"""
Format attributes for CLI.
Args:
data: The data to format.
prefix: The prefix to use for the keys.
Returns:
The formatted attributes.
"""
flattened = {}

for key, value in data.items():
current_key = f"{prefix}.{key}" if prefix else key

if isinstance(value, dict):
flattened.update(format_attributes(value, current_key))
elif isinstance(value, list | tuple):
flattened[current_key] = repr(
[
item if isinstance(item, str | float | int | bool) else repr(item)
for item in value # type: ignore
]
)
elif isinstance(value, str | float | int | bool):
flattened[current_key] = value
else:
flattened[current_key] = repr(value)

return flattened
45 changes: 6 additions & 39 deletions packages/ragbits-core/src/ragbits/core/audit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rich.live import Live
from rich.tree import Tree

from ragbits.core.audit import TraceHandler
from ragbits.core.audit.base import TraceHandler, format_attributes


class SpanStatus(Enum):
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, name: str, attributes: dict, parent: "CLISpan | None" = None)
self.end_time: float | None = None
self.status = SpanStatus.STARTED
self.tree = Tree("")
if self.parent:
if self.parent is not None:
self.parent.tree.add(self.tree)

def update(self) -> None:
Expand All @@ -68,7 +68,7 @@ def update(self) -> None:
# TODO: Remove truncating after implementing better CLI formatting.
attrs = [
f"[{PrintColor.PURPLE}]{k}:[/{PrintColor.PURPLE}] "
f"[{PrintColor.GRAY}]{v[:120] + ' (...)' if len(v) > 120 else v}[/{PrintColor.GRAY}]" # noqa: PLR2004
f"[{PrintColor.GRAY}]{str(v)[:120] + ' (...)' if len(str(v)) > 120 else v}[/{PrintColor.GRAY}]" # noqa: PLR2004
for k, v in self.attributes.items()
]
self.tree.label = f"{name}\n{chr(10).join(attrs)}" if attrs else name
Expand Down Expand Up @@ -104,7 +104,7 @@ def start(self, name: str, inputs: dict, current_span: CLISpan | None = None) ->
Returns:
The updated current trace span.
"""
attributes = _format_attributes(inputs, prefix="inputs")
attributes = format_attributes(inputs, prefix="inputs")
span = CLISpan(
name=name,
attributes=attributes,
Expand All @@ -128,7 +128,7 @@ def stop(self, outputs: dict, current_span: CLISpan) -> None:
outputs: The output data.
current_span: The current trace span.
"""
attributes = _format_attributes(outputs, prefix="outputs")
attributes = format_attributes(outputs, prefix="outputs")
current_span.attributes.update(attributes)
current_span.status = SpanStatus.COMPLETED
current_span.end()
Expand All @@ -147,7 +147,7 @@ def error(self, error: Exception, current_span: CLISpan) -> None:
error: The error that occurred.
current_span: The current trace span.
"""
attributes = _format_attributes({"message": str(error), **vars(error)}, prefix="error")
attributes = format_attributes({"message": str(error), **vars(error)}, prefix="error")
current_span.attributes.update(attributes)
current_span.status = SpanStatus.ERROR
current_span.end()
Expand All @@ -157,36 +157,3 @@ def error(self, error: Exception, current_span: CLISpan) -> None:

if current_span.parent is None:
self.live.stop()


def _format_attributes(data: dict, prefix: str | None = None) -> dict:
"""
Format attributes for CLI.
Args:
data: The data to format.
prefix: The prefix to use for the keys.
Returns:
The formatted attributes.
"""
flattened = {}

for key, value in data.items():
current_key = f"{prefix}.{key}" if prefix else key

if isinstance(value, dict):
flattened.update(_format_attributes(value, current_key))
elif isinstance(value, list | tuple):
flattened[current_key] = repr(
[
item if isinstance(item, str | float | int | bool) else repr(item)
for item in value # type: ignore
]
)
elif isinstance(value, str | float | int | bool):
flattened[current_key] = str(value)
else:
flattened[current_key] = repr(value)

return flattened
40 changes: 4 additions & 36 deletions packages/ragbits-core/src/ragbits/core/audit/otel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from opentelemetry import trace
from opentelemetry.trace import Span, StatusCode, TracerProvider
from opentelemetry.util.types import AttributeValue

from ragbits.core.audit.base import TraceHandler
from ragbits.core.audit.base import TraceHandler, format_attributes


class OtelTraceHandler(TraceHandler[Span]):
Expand Down Expand Up @@ -35,7 +34,7 @@ def start(self, name: str, inputs: dict, current_span: Span | None = None) -> Sp
context = trace.set_span_in_context(current_span) if current_span else None

with self._tracer.start_as_current_span(name, context=context, end_on_exit=False) as span:
attributes = _format_attributes(inputs, prefix="inputs")
attributes = format_attributes(inputs, prefix="inputs")
span.set_attributes(attributes)

return span
Expand All @@ -48,7 +47,7 @@ def stop(self, outputs: dict, current_span: Span) -> None: # noqa: PLR6301
outputs: The output data.
current_span: The current trace span.
"""
attributes = _format_attributes(outputs, prefix="outputs")
attributes = format_attributes(outputs, prefix="outputs")
current_span.set_attributes(attributes)
current_span.set_status(StatusCode.OK)
current_span.end()
Expand All @@ -61,38 +60,7 @@ def error(self, error: Exception, current_span: Span) -> None: # noqa: PLR6301
error: The error that occurred.
current_span: The current trace span.
"""
attributes = _format_attributes(vars(error), prefix="error")
attributes = format_attributes({"message": str(error), **vars(error)}, prefix="error")
current_span.set_attributes(attributes)
current_span.set_status(StatusCode.ERROR)
current_span.end()


def _format_attributes(data: dict, prefix: str | None = None) -> dict[str, AttributeValue]:
"""
Format attributes for OpenTelemetry.
Args:
data: The data to format.
prefix: The prefix to use for the keys.
Returns:
The formatted attributes.
"""
flattened = {}

for key, value in data.items():
current_key = f"{prefix}.{key}" if prefix else key

if isinstance(value, dict):
flattened.update(_format_attributes(value, current_key))
elif isinstance(value, list | tuple):
flattened[current_key] = [
item if isinstance(item, str | float | int | bool) else repr(item)
for item in value # type: ignore
]
elif isinstance(value, str | float | int | bool):
flattened[current_key] = value
else:
flattened[current_key] = repr(value)

return flattened
40 changes: 39 additions & 1 deletion packages/ragbits-core/tests/unit/audit/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest.mock import MagicMock, patch

from ragbits.core.audit.cli import CLISpan, CLITraceHandler, SpanStatus
from ragbits.core.audit.cli import CLISpan, CLITraceHandler, PrintColor, SpanStatus

TEST_NAME_1 = "process_1"
TEST_NAME_2 = "process_2"
Expand Down Expand Up @@ -48,6 +48,44 @@ def test_cli_span_end_method() -> None:
assert test_instance.status == SpanStatus.STARTED


def test_cli_span_to_tree() -> None:
with patch("time.perf_counter", side_effect=[11.0, 22.0, 33.0]):
parent_instance = CLISpan(name=TEST_NAME_1, attributes=TEST_INPUT)
second_instance = CLISpan(name=TEST_NAME_2, attributes={}, parent=parent_instance)
third_instance = CLISpan(name="process_3", attributes={}, parent=second_instance)
third_instance.status = SpanStatus.ERROR

parent_instance.update()
second_instance.update()
third_instance.update()

assert "process_1" in str(parent_instance.tree.label)
assert PrintColor.BLUE in str(parent_instance.tree.label)
assert "process_2" in str(second_instance.tree.label)
assert PrintColor.GRAY in str(second_instance.tree.label)
assert "process_3" in str(third_instance.tree.label)
assert PrintColor.RED in str(third_instance.tree.label)


def test_cli_trace_start() -> None:
trace_handler = CLITraceHandler()
parent_span = trace_handler.start(name=TEST_NAME_1, inputs=TEST_INPUT)

assert trace_handler.live is not None
assert trace_handler.tree is not None
assert TEST_NAME_1 in str(trace_handler.tree.label)
assert parent_span.name == TEST_NAME_1
assert parent_span.parent is None
assert parent_span.start_time is not None
assert parent_span.status == SpanStatus.STARTED

child_span = trace_handler.start(name=TEST_NAME_2, inputs={}, current_span=parent_span)
assert child_span.name == TEST_NAME_2
assert child_span.parent == parent_span
assert child_span.parent.name == TEST_NAME_1
trace_handler.live.stop()


def test__cli_trace_stop() -> None:
trace_handler = CLITraceHandler()
parent_span = trace_handler.start(name=TEST_NAME_1, inputs={})
Expand Down
42 changes: 0 additions & 42 deletions packages/ragbits-core/tests/unit/audit/test_otel.py

This file was deleted.

40 changes: 39 additions & 1 deletion packages/ragbits-core/tests/unit/audit/test_trace.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from collections.abc import Callable
from datetime import datetime
from unittest.mock import MagicMock

import pytest

from ragbits.core.audit import _get_function_inputs, set_trace_handlers, trace, traceable
from ragbits.core.audit.base import TraceHandler
from ragbits.core.audit.base import TraceHandler, format_attributes


class MockTraceHandler(TraceHandler):
Expand Down Expand Up @@ -158,3 +159,40 @@ def void_function(x: int) -> None:
def test_get_function_inputs(func: Callable, args: tuple, kwargs: dict, expected: dict) -> None:
result = _get_function_inputs(func, args, kwargs)
assert result == expected


@pytest.mark.parametrize(
("input_data", "prefix", "expected"),
[
# Empty dict
({}, None, {}),
({}, "test", {}),
# Simple types
(
{"str": "value", "int": 42, "float": 3.14, "bool": True},
None,
{"str": "value", "int": 42, "float": 3.14, "bool": True},
),
# With prefix
({"str": "value", "int": 42}, "prefix", {"prefix.str": "value", "prefix.int": 42}),
# Nested dict
({"nested": {"key1": "value1", "key2": 42}}, None, {"nested.key1": "value1", "nested.key2": 42}),
# Lists and tuples
({"list": [1, 2, 3], "tuple": ("a", "b", "c")}, None, {"list": "[1, 2, 3]", "tuple": "['a', 'b', 'c']"}),
# Complex objects in lists
(
{"objects": [{"a": 1}, datetime(2023, 1, 1)]},
None,
{"objects": "[\"{'a': 1}\", 'datetime.datetime(2023, 1, 1, 0, 0)']"},
),
# Mixed nested structure
(
{"level1": {"level2": {"string": "value", "list": [1, {"x": "y"}]}}},
"test",
{"test.level1.level2.string": "value", "test.level1.level2.list": "[1, \"{'x': 'y'}\"]"},
),
],
)
def test_format_attributes(input_data: dict, prefix: str, expected: dict) -> None:
result = format_attributes(input_data, prefix)
assert result == expected

0 comments on commit f4d9385

Please sign in to comment.