diff --git a/docs/assets/otel_handler_jeager.png b/docs/assets/otel_handler_jeager.png new file mode 100644 index 00000000..5f00f81f Binary files /dev/null and b/docs/assets/otel_handler_jeager.png differ diff --git a/docs/how-to/trace_runs_with_otel.md b/docs/how-to/trace_runs_with_otel.md new file mode 100644 index 00000000..1be7332c --- /dev/null +++ b/docs/how-to/trace_runs_with_otel.md @@ -0,0 +1,126 @@ +# How-To: Trace runs with OpenTelemetry + +db-ally provides you a way to track execution of the query processing using +[OpenTelemetry](https://opentelemetry.io/) standard. As db-ally is a library, it only depends on the +[OpenTelemtry API](https://opentelemetry.io/docs/specs/otel/overview/#api). For projects that use db-ally, include +[OpenTelemetry SDK](https://opentelemetry.io/docs/specs/otel/overview/#sdk) or perform +[Auto Instrumentation](https://opentelemetry.io/docs/zero-code/python/). + + +## Step-by-step guide + +1. [Python OpenTelemetry SDK](https://opentelemetry-python.readthedocs.io/en/latest/sdk/index.html) must be installed: + + ```bash + pip install opentelemetry-sdk + ``` + +2. To capture the traces, you can use [Jeager](https://www.jaegertracing.io/). An open-source software for telemetry + data. The recommended option is to start with Docker. You can run: + + ```bash + docker run --network host --rm --name jeager -e COLLECTOR_ZIPKIN_HTTP_PORT=9411 jaegertracing/all-in-one + ``` + + For simplicity we are using `--network host`, however, do not use this settings in production deployments and + expose only ports that are needed. + +3. Import required OpenTelemetry SDKs and db-ally OTel Handler: + + ```python + from dbally.audit.event_handlers.otel_event_handler import OtelEventHandler + + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.resources import Resource + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.trace.export import BatchSpanProcessor + ``` + +4. Setup the OTel exporter in your project: + + ```python + exporeter = OTLPSpanExporter("http://localhost:4317", insecure=True) + provider = TracerProvider(resource=Resource({"service.name": "db-ally"})) + processor = BatchSpanProcessor(exporeter) + provider.add_span_processor(processor) + handler = OtelEventHandler(provider) + ``` + + Using Resource you can add a name for your service. OTLPSpanExporter is used to export telemetry data using gRPC or + HTTP to desired location. We mark it as insecure, as demo does not use TLS. To efficently send data over network, + we should use BatchSpanProcessor to batch exports of telemetry data. Finally, we setup the db-ally handler. + +5. Use handler with collection: + + ```python + df = pd.DataFrame({ + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"], + }) + + llm = LiteLLM(model_name="gpt-4o") + collection = dbally.create_collection("clients", llm=llm, event_handlers=[handler], nl_responder=NLResponder(llm)) + collection.add(ClientView, lambda: ClientView(df)) + ``` + +6. Ask your questions: + + ```python + result = await collection.ask("What clients are from Huston?", return_natural_response=True) + print(result) + ``` + +7. Explore your traces in observability platform (Jeager in our case): + + ![Example trace in Jeager UI](../assets/otel_handler_jeager.png) + + +## Full code example + +```python +import asyncio +import pandas as pd + +import dbally +from dbally import DataFrameBaseView +from dbally.audit.event_handlers.otel_event_handler import OtelEventHandler +from dbally.nl_responder.nl_responder import NLResponder +from dbally.views import decorators +from dbally.llms import LiteLLM + +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.resources import Resource +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.trace.export import BatchSpanProcessor + + +class ClientView(DataFrameBaseView): + + @decorators.view_filter() + def filter_by_city(self, city: str): + return self.df['city'] == city + + +async def main(): + exporeter = OTLPSpanExporter("http://localhost:4317", insecure=True) + provider = TracerProvider(resource=Resource({"service.name": "db-ally"})) + processor = BatchSpanProcessor(exporeter) + provider.add_span_processor(processor) + handler = OtelEventHandler(provider) + + df = pd.DataFrame({ + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"], + }) + + llm = LiteLLM(model_name="gpt-4o") + collection = dbally.create_collection("clients", llm=llm, event_handlers=[handler], nl_responder=NLResponder(llm)) + collection.add(ClientView, lambda: ClientView(df)) + + result = await collection.ask("What clients are from Huston?", return_natural_response=True) + print(result) + + +if __name__ == '__main__': + asyncio.run(main()) +``` \ No newline at end of file diff --git a/docs/reference/event_handlers/otel_handler.md b/docs/reference/event_handlers/otel_handler.md new file mode 100644 index 00000000..1d996834 --- /dev/null +++ b/docs/reference/event_handlers/otel_handler.md @@ -0,0 +1,3 @@ +# OtelEventHandler + +::: dbally.audit.OtelEventHandler \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 59cdef42..0129b1ff 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -34,6 +34,7 @@ nav: - how-to/update_similarity_indexes.md - how-to/visualize_views.md - how-to/log_runs_to_langsmith.md + - how-to/trace_runs_with_otel.md - how-to/create_custom_event_handler.md - how-to/openai_assistants_integration.md - API Reference: @@ -54,6 +55,7 @@ nav: - reference/event_handlers/index.md - reference/event_handlers/cli_handler.md - reference/event_handlers/langsmith_handler.md + - reference/event_handlers/otel_handler.md - View Selection: - reference/view_selection/index.md - reference/view_selection/llm_view_selector.md diff --git a/setup.cfg b/setup.cfg index 7a162c5b..0b9683e1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = tabulate>=0.9.0 click~=8.1.7 numpy>=1.24.0 + opentelemetry-api>=1.0.0 [options.extras_require] litellm = diff --git a/src/dbally/audit/__init__.py b/src/dbally/audit/__init__.py index 73253f71..96f57f0e 100644 --- a/src/dbally/audit/__init__.py +++ b/src/dbally/audit/__init__.py @@ -7,14 +7,15 @@ except ImportError: pass +from .event_handlers.otel_event_handler import OtelEventHandler from .event_tracker import EventTracker -from .events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from .events import LLMEvent, RequestEnd, RequestStart, SimilarityEvent from .spans import EventSpan __all__ = [ "CLIEventHandler", "LangSmithEventHandler", - "Event", + "OtelEventHandler", "EventHandler", "EventTracker", "EventSpan", diff --git a/src/dbally/audit/event_handlers/otel_event_handler.py b/src/dbally/audit/event_handlers/otel_event_handler.py new file mode 100644 index 00000000..00a106a2 --- /dev/null +++ b/src/dbally/audit/event_handlers/otel_event_handler.py @@ -0,0 +1,220 @@ +import json +from dataclasses import dataclass +from typing import Any, Callable, Optional + +from opentelemetry import trace +from opentelemetry.trace import Span, SpanKind, StatusCode, TracerProvider +from opentelemetry.util.types import AttributeValue + +from dbally.audit.event_handlers.base import EventHandler +from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent + +TRACER_NAME = "db-ally.events" +FORBIDDEN_CONTEXT_KEYS = {"filter_mask"} + +TransformFn = Optional[Callable[[Any], Optional[AttributeValue]]] + + +def _optional_str(value: Optional[any]) -> Optional[str]: + return None if value is None else str(value) + + +@dataclass +class SpanHandler: + """Handles span attributes and lifecycle""" + + span: Span + record_inputs: bool + record_outputs: bool + + def set(self, key: str, value: Optional[Any], transform: TransformFn = None) -> "SpanHandler": + """ + Sets a value as span attribute under given key if the value exists. Optionally one can add transform function to + change value from any to valid OpenTelemetry attribute type. + + Args: + key: attribute name + value: attribute value. If None, the value is not set + transform: optional function to transform from Any to valid OTel AttributeValue + + Returns: + self, for chaining calls + """ + value = value if transform is None else transform(value) + if value is not None: + self.span.set_attribute(key, value) + + return self + + def set_input(self, key: str, value: Optional[Any], transform: TransformFn = None) -> "SpanHandler": + """ + Sets a value, that is used as model input, under given key if the value exists. If the class does not record + inputs, then the value is not set. Optionally one can add transform function to change value from any to valid + OpenTelemetry attribute type. + + Args: + key: attribute name + value: attribute value. If None, the value is not set. If record_inputs is False, the value is not set. + transform: optional function to transform from Any to valid OTel AttributeValue + + Returns: + self, for chaining calls + """ + value = value if transform is None else transform(value) + if value is not None and self.record_inputs: + self.span.set_attribute(key, value) + + return self + + def set_output(self, key: str, value: Optional[Any], transform: TransformFn = None) -> "SpanHandler": + """ + Sets a value, that is the model output under, given key if the value exists. If the class does not record + inputs, then the value is not set. Optionally one can add transform function to change value from any to valid + OpenTelemetry attribute type. + + Args: + key: attribute name + value: attribute value. If None, the value is not set. If record_output is False, the value is not set. + transform: optional function to transform from Any to valid OTel AttributeValue + + Returns: + self, for chaining calls + """ + value = value if transform is None else transform(value) + if value is not None and self.record_outputs: + self.span.set_attribute(key, value) + + return self + + def end_succesfully(self) -> None: + """Sets status of the span to OK and ends the span with current time""" + self.span.set_status(StatusCode.OK) + self.span.end() + + +class OtelEventHandler(EventHandler[SpanHandler, SpanHandler]): + """ + This handler emits OpenTelemetry spans for recorded events. + """ + + def __init__( + self, provider: Optional[TracerProvider] = None, record_inputs: bool = True, record_outputs: bool = True + ) -> None: + """ + Initialize OtelEventHandler. By default, it will try to use globaly configured TracerProvider. Pass it + explicitly if you want custom implementation, or you do not use OTel auto-instrumentation. + + To comply with the + [OTel Semantic Conventions](https://opentelemetry.io/docs/specs/semconv/gen-ai/llm-spans/#configuration) + recording of inputs and outputs can be disabled. + + Args: + provider: Optional tracer provider. By default global provider is used. + record_inputs: if true (default) all inputs are recorded as span attributes. Depending on usecase it maybe + turned off, to save resources and improve performance. + record_outputs: if true (default) all outputs are recorded as span attributes. Depending on usecase it + maybe turned off, to save resources and improve performance. + """ + self.record_inputs = record_inputs + self.record_outputs = record_outputs + if provider is None: + self.tracer = trace.get_tracer(TRACER_NAME) + else: + self.tracer = provider.get_tracer(TRACER_NAME) + + def _handle_span(self, span: Span) -> SpanHandler: + return SpanHandler(span, self.record_inputs, self.record_outputs) + + async def request_start(self, user_request: RequestStart) -> SpanHandler: + """ + Initializes new OTel Span as a parent. + + Args: + user_request: The start of the request. + + Returns: + span object as a parent for all subsequent events for this request + """ + with self.tracer.start_as_current_span("request", end_on_exit=False, kind=SpanKind.SERVER) as span: + return ( + self._handle_span(span) + .set("db-ally.user.collection", user_request.collection_name) + .set_input("db-ally.user.question", user_request.question) + ) + + async def event_start(self, event: Event, request_context: SpanHandler) -> SpanHandler: + """ + Starts a new event in a system as a span. Uses request span as a parent. + + Args: + event: Event to register + request_context: Parent span for this event + + Returns: + span object capturing start of execution for this event + + Raises: + ValueError: it is thrown when unknown event type is passed as argument + """ + if isinstance(event, LLMEvent): + with self._new_child_span(request_context, "llm") as span: + return ( + self._handle_span(span) + .set("db-ally.llm.type", event.type) + .set_input("db-ally.llm.prompts", json.dumps(event.prompt)) + ) + + if isinstance(event, SimilarityEvent): + with self._new_child_span(request_context, "similarity") as span: + return ( + self._handle_span(span) + .set("db-ally.similarity.store", event.store) + .set("db-ally.similarity.fetcher", event.fetcher) + .set_input("db-ally.similarity.input", event.input_value) + ) + + raise ValueError(f"Unsuported event: {type(event)}") + + async def event_end(self, event: Optional[Event], request_context: SpanHandler, event_context: SpanHandler) -> None: + """ + Finalizes execution of the event, ending a span for this event. + + Args: + event: optional event information + request_context: parent span + event_context: event span + """ + + if isinstance(event, LLMEvent): + event_context.set("db-ally.llm.response-tokes", event.completion_tokens).set_output( + "db-ally.llm.response", event.response + ) + + if isinstance(event, SimilarityEvent) and self.record_outputs: + event_context.set("db-ally.similarity.output", event.output_value) + + event_context.end_succesfully() + + async def request_end(self, output: RequestEnd, request_context: SpanHandler) -> None: + """ + Finalizes entire request, ending the span for this request. + + Args: + output: output generated for this request + request_context: span to be closed + """ + request_context.set_output("db-ally.result.textual", output.result.textual_response).set( + "db-ally.result.execution-time", output.result.execution_time + ).set("db-ally.result.execution-time-view", output.result.execution_time_view).set( + "db-ally.result.view-name", output.result.view_name + ) + + for key, value in output.result.context.items(): + if key not in FORBIDDEN_CONTEXT_KEYS: + request_context.set(f"db-ally.result.context.{key}", value, transform=_optional_str) + + request_context.end_succesfully() + + def _new_child_span(self, parent: SpanHandler, name: str): + context = trace.set_span_in_context(parent.span) + return self.tracer.start_as_current_span(name, context=context, end_on_exit=False, kind=SpanKind.CLIENT) diff --git a/src/dbally/iql/_exceptions.py b/src/dbally/iql/_exceptions.py index 7df08187..99279890 100644 --- a/src/dbally/iql/_exceptions.py +++ b/src/dbally/iql/_exceptions.py @@ -1,5 +1,5 @@ import ast -from typing import Optional, Union +from typing import List, Optional from dbally.exceptions import DbAllyError @@ -7,26 +7,65 @@ class IQLError(DbAllyError): """Base exception for all IQL parsing related exceptions.""" - def __init__(self, message: str, node: Union[ast.stmt, ast.expr], source: str) -> None: - message = message + ": " + source[node.col_offset : node.end_col_offset] - + def __init__(self, message: str, source: str) -> None: super().__init__(message) - self.node = node self.source = source -class IQLArgumentParsingError(IQLError): +class IQLSyntaxError(IQLError): + """Raised when IQL syntax is invalid.""" + + def __init__(self, source: str) -> None: + message = f"Syntax error in: {source}" + super().__init__(message, source) + + +class IQLEmptyExpressionError(IQLError): + """Raised when IQL expression is empty.""" + + def __init__(self, source: str) -> None: + message = "Empty IQL expression" + super().__init__(message, source) + + +class IQLMultipleExpressionsError(IQLError): + """Raised when IQL contains multiple expressions.""" + + def __init__(self, nodes: List[ast.stmt], source: str) -> None: + message = "Multiple expressions or statements in IQL are not supported" + super().__init__(message, source) + self.nodes = nodes + + +class IQLExpressionError(IQLError): + """Raised when IQL expression is invalid.""" + + def __init__(self, message: str, node: ast.expr, source: str) -> None: + message = f"{message}: {source[node.col_offset : node.end_col_offset]}" + super().__init__(message, source) + self.node = node + + +class IQLNoExpressionError(IQLExpressionError): + """Raised when IQL expression is not found.""" + + def __init__(self, node: ast.stmt, source: str) -> None: + message = "No expression found in IQL" + super().__init__(message, node, source) + + +class IQLArgumentParsingError(IQLExpressionError): """Raised when an argument cannot be parsed into a valid IQL.""" - def __init__(self, node: Union[ast.stmt, ast.expr], source: str) -> None: + def __init__(self, node: ast.expr, source: str) -> None: message = "Not a valid IQL argument" super().__init__(message, node, source) -class IQLUnsupportedSyntaxError(IQLError): +class IQLUnsupportedSyntaxError(IQLExpressionError): """Raised when trying to parse an unsupported syntax.""" - def __init__(self, node: Union[ast.stmt, ast.expr], source: str, context: Optional[str] = None) -> None: + def __init__(self, node: ast.expr, source: str, context: Optional[str] = None) -> None: node_name = node.__class__.__name__ message = f"{node_name} syntax is not supported in IQL" @@ -37,7 +76,7 @@ def __init__(self, node: Union[ast.stmt, ast.expr], source: str, context: Option super().__init__(message, node, source) -class IQLFunctionNotExists(IQLError): +class IQLFunctionNotExists(IQLExpressionError): """Raised when IQL contains function call to a function that not exists.""" def __init__(self, node: ast.Name, source: str) -> None: @@ -45,5 +84,13 @@ def __init__(self, node: ast.Name, source: str) -> None: super().__init__(message, node, source) -class IQLArgumentValidationError(IQLError): +class IQLIncorrectNumberArgumentsError(IQLExpressionError): + """Raised when IQL contains too many arguments for a function.""" + + def __init__(self, node: ast.Call, source: str) -> None: + message = f"The method {node.func.id} has incorrect number of arguments" + super().__init__(message, node, source) + + +class IQLArgumentValidationError(IQLExpressionError): """Raised when argument is not valid for a given method.""" diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 8127ddfe..7ff81d8f 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -6,8 +6,12 @@ from dbally.iql._exceptions import ( IQLArgumentParsingError, IQLArgumentValidationError, - IQLError, + IQLEmptyExpressionError, IQLFunctionNotExists, + IQLIncorrectNumberArgumentsError, + IQLMultipleExpressionsError, + IQLNoExpressionError, + IQLSyntaxError, IQLUnsupportedSyntaxError, ) from dbally.iql._type_validators import validate_arg_type @@ -33,21 +37,28 @@ async def process(self) -> syntax.Node: Process IQL string to root IQL.Node. Returns: - IQL.Node which is root of the tree representing IQL query. + IQL node which is root of the tree representing IQL query. Raises: - IQLError: if parsing fails. + IQLError: If parsing fails. """ self.source = self._to_lower_except_in_quotes(self.source, ["AND", "OR", "NOT"]) - ast_tree = ast.parse(self.source) - first_element = ast_tree.body[0] + try: + ast_tree = ast.parse(self.source) + except (SyntaxError, ValueError) as exc: + raise IQLSyntaxError(self.source) from exc - if not isinstance(first_element, ast.Expr): - raise IQLError("Not a valid IQL expression", first_element, self.source) + if not ast_tree.body: + raise IQLEmptyExpressionError(self.source) - root = await self._parse_node(first_element.value) - return root + if len(ast_tree.body) > 1: + raise IQLMultipleExpressionsError(ast_tree.body, self.source) + + if not isinstance(ast_tree.body[0], ast.Expr): + raise IQLNoExpressionError(ast_tree.body[0], self.source) + + return await self._parse_node(ast_tree.body[0].value) async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: if isinstance(node, ast.BoolOp): @@ -82,7 +93,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: args = [] if len(func_def.parameters) != len(node.args): - raise ValueError(f"The method {func.id} has incorrect number of arguments") + raise IQLIncorrectNumberArgumentsError(node, self.source) for arg, arg_def in zip(node.args, func_def.parameters): arg_value = self._parse_arg(arg) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index c2131a57..dd831a91 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -33,11 +33,15 @@ async def parse( Parse IQL string to IQLQuery object. Args: - source: IQL string that needs to be parsed - allowed_functions: list of IQL functions that are allowed for this query - event_tracker: EventTracker object to track events + source: IQL string that needs to be parsed. + allowed_functions: List of IQL functions that are allowed for this query. + event_tracker: EventTracker object to track events. + Returns: - IQLQuery object + IQLQuery object. + + Raises: + IQLError: If parsing fails. """ root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() return cls(root=root, source=source) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 4d028fa9..4b6dbe63 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -57,6 +57,9 @@ async def generate_iql( Returns: Generated IQL query. + + Raises: + IQLError: If IQL generation fails after all retries. """ prompt_format = IQLGenerationPromptFormat( question=question, @@ -66,7 +69,7 @@ async def generate_iql( formatted_prompt = self._prompt_template.format_prompt(prompt_format) - for _ in range(n_retries + 1): + for retry in range(n_retries + 1): try: response = await self._llm.generate_text( prompt=formatted_prompt, @@ -82,5 +85,7 @@ async def generate_iql( event_tracker=event_tracker, ) except IQLError as exc: + if retry == n_retries: + raise exc formatted_prompt = formatted_prompt.add_assistant_message(response) formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index db81894b..89d8d19c 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -69,6 +69,9 @@ async def ask( Returns: The result of the query. + + Raises: + IQLError: If the generated IQL query is not valid. """ iql_generator = self.get_iql_generator(llm) agg_formatter = self.get_agg_formatter(llm) diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index 411ffe14..62a6766d 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY, AsyncMock, call +from unittest.mock import ANY, AsyncMock, call, patch import pytest @@ -24,11 +24,12 @@ async def test_llm_options_propagation(): collection.add(MockView1) collection.add(MockView2) - await collection.ask( - question="Mock question", - return_natural_response=True, - llm_options=custom_options, - ) + with patch("dbally.iql.IQLQuery.parse", AsyncMock()): + await collection.ask( + question="Mock question", + return_natural_response=True, + llm_options=custom_options, + ) assert llm.client.call.call_count == 4 diff --git a/tests/unit/audit/__init__.py b/tests/unit/audit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/audit/event_handlers/__init__.py b/tests/unit/audit/event_handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/audit/event_handlers/test_otel_event_handler.py b/tests/unit/audit/event_handlers/test_otel_event_handler.py new file mode 100644 index 00000000..c7a8a7a6 --- /dev/null +++ b/tests/unit/audit/event_handlers/test_otel_event_handler.py @@ -0,0 +1,117 @@ +from typing import Dict, Mapping, Optional, Union + +from opentelemetry import trace +from opentelemetry.trace import Span, StatusCode +from opentelemetry.util.types import AttributeValue + +from dbally.audit.event_handlers.otel_event_handler import SpanHandler + + +class MockSpan(Span): + def __init__(self) -> None: + super().__init__() + self.attributes = {} + self.status = StatusCode.UNSET + self.is_finished = False + + def end(self, end_time: Optional[int] = None) -> None: + self.is_finished = True + + def get_span_context(self) -> trace.SpanContext: + raise NotImplementedError + + def set_attributes(self, attributes: Dict[str, AttributeValue]) -> None: + self.attributes.update(attributes) + + def set_attribute(self, key: str, value: AttributeValue) -> None: + self.attributes[key] = value + + def add_event( + self, name: str, attributes: Optional[Mapping[str, AttributeValue]] = None, timestamp: Optional[int] = None + ) -> None: + raise NotImplementedError + + def update_name(self, name: str) -> None: + raise NotImplementedError + + def is_recording(self) -> bool: + raise NotImplementedError + + def set_status(self, status: Union[trace.Status, StatusCode], description: Optional[str] = None) -> None: + self.status = status.status_code if isinstance(status, trace.Status) else status + + def record_exception( + self, + exception: BaseException, + attributes: Optional[Mapping[str, AttributeValue]] = None, + timestamp: Optional[int] = None, + escaped: bool = False, + ) -> None: + raise NotImplementedError + + +def test_span_handler_sets_all(): + span = MockSpan() + + handler = SpanHandler(span, record_inputs=True, record_outputs=True) + handler.set("standard", "1") + handler.set_input("inputs", "2") + handler.set_output("outputs", "3") + handler.end_succesfully() + + assert span.attributes.get("standard") == "1" + assert span.attributes.get("inputs") == "2" + assert span.attributes.get("outputs") == "3" + assert span.status == StatusCode.OK + assert span.is_finished + + +def test_span_handler_sets_without_input(): + span = MockSpan() + + handler = SpanHandler(span, record_inputs=False, record_outputs=True) + handler.set("standard", "1") + handler.set_input("inputs", "2") + handler.set_output("outputs", "3") + handler.end_succesfully() + + assert span.attributes.get("standard") == "1" + assert span.attributes.get("inputs") is None + assert span.attributes.get("outputs") == "3" + assert span.status == StatusCode.OK + assert span.is_finished + + +def test_span_handler_sets_without_outputs(): + span = MockSpan() + + handler = SpanHandler(span, record_inputs=True, record_outputs=False) + handler.set("standard", "1") + handler.set_input("inputs", "2") + handler.set_output("outputs", "3") + handler.end_succesfully() + + assert span.attributes.get("standard") == "1" + assert span.attributes.get("inputs") == "2" + assert span.attributes.get("outputs") is None + assert span.status == StatusCode.OK + assert span.is_finished + + +def test_span_handler_sets_with_transformation(): + span = MockSpan() + + def transform_fn(x: str): + return None if x == "foo" else x.upper() + + handler = SpanHandler(span, record_inputs=True, record_outputs=True) + handler.set("standard", "foo", transform=transform_fn) + handler.set_input("inputs", "bar", transform=transform_fn) + handler.set_output("outputs", "baz", transform=transform_fn) + handler.end_succesfully() + + assert span.attributes.get("standard") is None + assert span.attributes.get("inputs") == "BAR" + assert span.attributes.get("outputs") == "BAZ" + assert span.status == StatusCode.OK + assert span.is_finished diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 94b66e28..0d018f4e 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -4,7 +4,15 @@ import pytest from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax -from dbally.iql._exceptions import IQLArgumentValidationError, IQLFunctionNotExists +from dbally.iql._exceptions import ( + IQLArgumentValidationError, + IQLEmptyExpressionError, + IQLFunctionNotExists, + IQLIncorrectNumberArgumentsError, + IQLMultipleExpressionsError, + IQLNoExpressionError, + IQLSyntaxError, +) from dbally.iql._processor import IQLProcessor from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping @@ -68,6 +76,78 @@ async def test_iql_parser_arg_error(): assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) +async def test_iql_parser_syntax_error(): + with pytest.raises(IQLSyntaxError) as exc_info: + await IQLQuery.parse( + "filter_by_age(", + allowed_functions=[ + ExposedFunction( + name="filter_by_age", + description="", + parameters=[ + MethodParamWithTyping(name="age", type=int), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Syntax error in: filter_by_age(")) + + +async def test_iql_parser_multiple_expression_error(): + with pytest.raises(IQLMultipleExpressionsError) as exc_info: + await IQLQuery.parse( + "filter_by_age\nfilter_by_age", + allowed_functions=[ + ExposedFunction( + name="filter_by_age", + description="", + parameters=[ + MethodParamWithTyping(name="age", type=int), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Multiple expressions or statements in IQL are not supported")) + + +async def test_iql_parser_empty_expression_error(): + with pytest.raises(IQLEmptyExpressionError) as exc_info: + await IQLQuery.parse( + "", + allowed_functions=[ + ExposedFunction( + name="filter_by_age", + description="", + parameters=[ + MethodParamWithTyping(name="age", type=int), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Empty IQL expression")) + + +async def test_iql_parser_no_expression_error(): + with pytest.raises(IQLNoExpressionError) as exc_info: + await IQLQuery.parse( + "import filter_by_age", + allowed_functions=[ + ExposedFunction( + name="filter_by_age", + description="", + parameters=[ + MethodParamWithTyping(name="age", type=int), + ], + ), + ], + ) + + assert exc_info.match(re.escape("No expression found in IQL: import filter_by_age")) + + async def test_iql_parser_unsupported_syntax_error(): with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: await IQLQuery.parse( @@ -104,6 +184,26 @@ async def test_iql_parser_method_not_exists(): assert exc_info.match(re.escape("Function filter_by_how_old_somebody_is not exists: filter_by_how_old_somebody_is")) +async def test_iql_parser_incorrect_number_of_arguments_fail(): + with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: + await IQLQuery.parse( + "filter_by_age('too old', 40)", + allowed_functions=[ + ExposedFunction( + name="filter_by_age", + description="", + parameters=[ + MethodParamWithTyping(name="age", type=int), + ], + ), + ], + ) + + assert exc_info.match( + re.escape("The method filter_by_age has incorrect number of arguments: filter_by_age('too old', 40)") + ) + + async def test_iql_parser_argument_validation_fail(): with pytest.raises(IQLArgumentValidationError) as exc_info: await IQLQuery.parse( diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 4a79c394..f98b5d3d 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,6 +1,6 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch import pytest import sqlalchemy @@ -87,30 +87,51 @@ async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventT @pytest.mark.asyncio -async def test_iql_generation_error_handling( +async def test_iql_generation_error_escalation_after_max_retires( iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - - mock_node = Mock(col_offset=0, end_col_offset=-1) - errors = [ - IQLError("err1", mock_node, "src1"), - IQLError("err2", mock_node, "src2"), - IQLError("err3", mock_node, "src3"), - IQLError("err4", mock_node, "src4"), + responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + IQLError("err4", "src4"), ] - with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: - mock_parse.side_effect = errors + with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)), pytest.raises(IQLError): iql = await iql_generator.generate_iql( question="Mock_question", filters=filters, event_tracker=event_tracker, + n_retries=3, ) assert iql is None assert iql_generator._llm.generate_text.call_count == 4 for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + + +@pytest.mark.asyncio +async def test_iql_generation_response_after_max_retries( + iql_generator: IQLGenerator, + event_tracker: EventTracker, + view: MockView, +) -> None: + filters = view.list_filters() + responses = [IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), "filter_by_id(1)"] + + with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)): + iql = await iql_generator.generate_iql( + question="Mock_question", + filters=filters, + event_tracker=event_tracker, + n_retries=3, + ) + + assert iql == "filter_by_id(1)" + assert iql_generator._llm.generate_text.call_count == 4 + for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"]