From 1e3ac825be60c39dcc1ea06cb11ff0eac386656c Mon Sep 17 00:00:00 2001 From: alex-stoica Date: Mon, 18 Nov 2024 17:46:28 +0200 Subject: [PATCH] Fixed TypeError in LangfuseTrace (#1184) * Added parent_span functionality in trace method * solved PR comments * Readded "end()" for solving Latency issues * chore: fix ruff linting * Handle multiple runs * Fix indentation and span closing * Fix tests --------- Co-authored-by: Vladimir Blagojevic Co-authored-by: Silvano Cerza --- .../tracing/langfuse/tracer.py | 102 ++++++++++-------- integrations/langfuse/tests/test_tracer.py | 2 +- integrations/langfuse/tests/test_tracing.py | 31 +++--- 3 files changed, 77 insertions(+), 58 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index c9c8a354e..c1f8d4d93 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,9 +1,10 @@ import contextlib -import logging import os +from contextvars import ContextVar from datetime import datetime -from typing import Any, Dict, Iterator, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union +from haystack import logging from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage from haystack.tracing import Span, Tracer, tracer @@ -32,6 +33,17 @@ ] _ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS +# These are the keys used by Haystack for traces and span. +# We keep them here to avoid making typos when using them. +_PIPELINE_RUN_KEY = "haystack.pipeline.run" +_COMPONENT_NAME_KEY = "haystack.component.name" +_COMPONENT_TYPE_KEY = "haystack.component.type" +_COMPONENT_OUTPUT_KEY = "haystack.component.output" + +# Context var used to keep track of tracing related info. +# This mainly useful for parents spans. +tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context", default={}) + class LangfuseSpan(Span): """ @@ -86,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None: self._data[key] = value - def raw_span(self) -> Any: + def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]": """ Return the underlying span instance. @@ -115,41 +127,57 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: and only accessible to the Langfuse account owner. """ self._tracer = tracer - self._context: list[LangfuseSpan] = [] + self._context: List[LangfuseSpan] = [] self._name = name self._public = public self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" @contextlib.contextmanager - def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]: - """ - Start and manage a new trace span. - :param operation_name: The name of the operation. - :param tags: A dictionary of tags to attach to the span. - :return: A context manager yielding the span. - """ + def trace( + self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None + ) -> Iterator[Span]: tags = tags or {} - span_name = tags.get("haystack.component.name", operation_name) - - if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS: - span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name)) + span_name = tags.get(_COMPONENT_NAME_KEY, operation_name) + + # Create new span depending whether there's a parent span or not + if not parent_span: + if operation_name != _PIPELINE_RUN_KEY: + logger.warning( + "Creating a new trace without a parent span is not recommended for operation '{operation_name}'.", + operation_name=operation_name, + ) + # Create a new trace if no parent span is provided + span = LangfuseSpan( + self._tracer.trace( + name=self._name, + public=self._public, + id=tracing_context_var.get().get("trace_id"), + user_id=tracing_context_var.get().get("user_id"), + session_id=tracing_context_var.get().get("session_id"), + tags=tracing_context_var.get().get("tags"), + version=tracing_context_var.get().get("version"), + ) + ) + elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS: + span = LangfuseSpan(parent_span.raw_span().generation(name=span_name)) else: - span = LangfuseSpan(self.current_span().raw_span().span(name=span_name)) + span = LangfuseSpan(parent_span.raw_span().span(name=span_name)) self._context.append(span) span.set_tags(tags) yield span - if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS: - meta = span._data.get("haystack.component.output", {}).get("meta") + # Update span metadata based on component type + if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS: + # Haystack returns one meta dict for each message, but the 'usage' value + # is always the same, let's just pick the first item + meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta") if meta: - # Haystack returns one meta dict for each message, but the 'usage' value - # is always the same, let's just pick the first item m = meta[0] span._span.update(usage=m.get("usage") or None, model=m.get("model")) - elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS: - replies = span._data.get("haystack.component.output", {}).get("replies") + elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS: + replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies") if replies: meta = replies[0].meta completion_start_time = meta.get("completion_start_time") @@ -165,36 +193,24 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I completion_start_time=completion_start_time, ) - pipeline_input = tags.get("haystack.pipeline.input_data", None) - if pipeline_input: - span._span.update(input=tags["haystack.pipeline.input_data"]) - pipeline_output = tags.get("haystack.pipeline.output_data", None) - if pipeline_output: - span._span.update(output=tags["haystack.pipeline.output_data"]) - - span.raw_span().end() + raw_span = span.raw_span() + if isinstance(raw_span, langfuse.client.StatefulSpanClient): + raw_span.end() self._context.pop() - if len(self._context) == 1: - # The root span has to be a trace, which need to be removed from the context after the pipeline run - self._context.pop() - - if self.enforce_flush: - self.flush() + if self.enforce_flush: + self.flush() def flush(self): self._tracer.flush() - def current_span(self) -> Span: + def current_span(self) -> Optional[Span]: """ - Return the currently active span. + Return the current active span. - :return: The currently active span. + :return: The current span if available, else None. """ - if not self._context: - # The root span has to be a trace - self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public))) - return self._context[-1] + return self._context[-1] if self._context else None def get_trace_url(self) -> str: """ diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index 9ee8e5dc4..42ae1d07d 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -69,7 +69,7 @@ def test_create_new_span(self): tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False) with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span: - assert len(tracer._context) == 2, "The trace span should have been added to the the root context span" + assert len(tracer._context) == 1, "The trace span should have been added to the the root context span" assert span.raw_span().operation_name == "operation_name" assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"} diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 657b6eae1..e5737b861 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -52,25 +52,28 @@ def test_tracing_integration(llm_class, env_var, expected_trace): assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] - # add a random delay between 1 and 3 seconds to make sure the trace is flushed - # and that the trace is available in Langfuse when we fetch it below - time.sleep(random.uniform(1, 3)) - - url = "https://cloud.langfuse.com/api/public/traces/" trace_url = response["tracer"]["trace_url"] uuid = os.path.basename(urlparse(trace_url).path) + url = f"https://cloud.langfuse.com/api/public/traces/{uuid}" - try: - response = requests.get( - url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) + # Poll the Langfuse API a bit as the trace might not be ready right away + attempts = 5 + delay = 1 + while attempts >= 0: + res = requests.get( + url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) ) - assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}" + if attempts > 0 and res.status_code != 200: + attempts -= 1 + time.sleep(delay) + delay *= 2 + continue + assert res.status_code == 200, f"Failed to retrieve data from Langfuse API: {res.status_code}" # check if the trace contains the expected LLM name - assert expected_trace in str(response.content) + assert expected_trace in str(res.content) # check if the trace contains the expected generation span - assert "GENERATION" in str(response.content) + assert "GENERATION" in str(res.content) # check if the trace contains the expected user_id - assert "user_42" in str(response.content) - except requests.exceptions.RequestException as e: - pytest.fail(f"Failed to retrieve data from Langfuse API: {e}") + assert "user_42" in str(res.content) + break