Skip to content

Commit

Permalink
Fixed TypeError in LangfuseTrace (#1184)
Browse files Browse the repository at this point in the history
* Added parent_span functionality in trace method

* solved PR comments

* Readded "end()" for solving Latency issues

* chore: fix ruff linting

* Handle multiple runs

* Fix indentation and span closing

* Fix tests

---------

Co-authored-by: Vladimir Blagojevic <[email protected]>
Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2024
1 parent 67e08d0 commit 1e3ac82
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 58 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import contextlib
import logging
import os
from contextvars import ContextVar
from datetime import datetime
from typing import Any, Dict, Iterator, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union

from haystack import logging
from haystack.components.generators.openai_utils import _convert_message_to_openai_format
from haystack.dataclasses import ChatMessage
from haystack.tracing import Span, Tracer, tracer
Expand Down Expand Up @@ -32,6 +33,17 @@
]
_ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS

# These are the keys used by Haystack for traces and span.
# We keep them here to avoid making typos when using them.
_PIPELINE_RUN_KEY = "haystack.pipeline.run"
_COMPONENT_NAME_KEY = "haystack.component.name"
_COMPONENT_TYPE_KEY = "haystack.component.type"
_COMPONENT_OUTPUT_KEY = "haystack.component.output"

# Context var used to keep track of tracing related info.
# This mainly useful for parents spans.
tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context", default={})


class LangfuseSpan(Span):
"""
Expand Down Expand Up @@ -86,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None:

self._data[key] = value

def raw_span(self) -> Any:
def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]":
"""
Return the underlying span instance.
Expand Down Expand Up @@ -115,41 +127,57 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public:
and only accessible to the Langfuse account owner.
"""
self._tracer = tracer
self._context: list[LangfuseSpan] = []
self._context: List[LangfuseSpan] = []
self._name = name
self._public = public
self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true"

@contextlib.contextmanager
def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]:
"""
Start and manage a new trace span.
:param operation_name: The name of the operation.
:param tags: A dictionary of tags to attach to the span.
:return: A context manager yielding the span.
"""
def trace(
self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None
) -> Iterator[Span]:
tags = tags or {}
span_name = tags.get("haystack.component.name", operation_name)

if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS:
span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name))
span_name = tags.get(_COMPONENT_NAME_KEY, operation_name)

# Create new span depending whether there's a parent span or not
if not parent_span:
if operation_name != _PIPELINE_RUN_KEY:
logger.warning(
"Creating a new trace without a parent span is not recommended for operation '{operation_name}'.",
operation_name=operation_name,
)
# Create a new trace if no parent span is provided
span = LangfuseSpan(
self._tracer.trace(
name=self._name,
public=self._public,
id=tracing_context_var.get().get("trace_id"),
user_id=tracing_context_var.get().get("user_id"),
session_id=tracing_context_var.get().get("session_id"),
tags=tracing_context_var.get().get("tags"),
version=tracing_context_var.get().get("version"),
)
)
elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS:
span = LangfuseSpan(parent_span.raw_span().generation(name=span_name))
else:
span = LangfuseSpan(self.current_span().raw_span().span(name=span_name))
span = LangfuseSpan(parent_span.raw_span().span(name=span_name))

self._context.append(span)
span.set_tags(tags)

yield span

if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS:
meta = span._data.get("haystack.component.output", {}).get("meta")
# Update span metadata based on component type
if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS:
# Haystack returns one meta dict for each message, but the 'usage' value
# is always the same, let's just pick the first item
meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta")
if meta:
# Haystack returns one meta dict for each message, but the 'usage' value
# is always the same, let's just pick the first item
m = meta[0]
span._span.update(usage=m.get("usage") or None, model=m.get("model"))
elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS:
replies = span._data.get("haystack.component.output", {}).get("replies")
elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS:
replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies")
if replies:
meta = replies[0].meta
completion_start_time = meta.get("completion_start_time")
Expand All @@ -165,36 +193,24 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I
completion_start_time=completion_start_time,
)

pipeline_input = tags.get("haystack.pipeline.input_data", None)
if pipeline_input:
span._span.update(input=tags["haystack.pipeline.input_data"])
pipeline_output = tags.get("haystack.pipeline.output_data", None)
if pipeline_output:
span._span.update(output=tags["haystack.pipeline.output_data"])

span.raw_span().end()
raw_span = span.raw_span()
if isinstance(raw_span, langfuse.client.StatefulSpanClient):
raw_span.end()
self._context.pop()

if len(self._context) == 1:
# The root span has to be a trace, which need to be removed from the context after the pipeline run
self._context.pop()

if self.enforce_flush:
self.flush()
if self.enforce_flush:
self.flush()

def flush(self):
self._tracer.flush()

def current_span(self) -> Span:
def current_span(self) -> Optional[Span]:
"""
Return the currently active span.
Return the current active span.
:return: The currently active span.
:return: The current span if available, else None.
"""
if not self._context:
# The root span has to be a trace
self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public)))
return self._context[-1]
return self._context[-1] if self._context else None

def get_trace_url(self) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion integrations/langfuse/tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_create_new_span(self):
tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False)

with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span:
assert len(tracer._context) == 2, "The trace span should have been added to the the root context span"
assert len(tracer._context) == 1, "The trace span should have been added to the the root context span"
assert span.raw_span().operation_name == "operation_name"
assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"}

Expand Down
31 changes: 17 additions & 14 deletions integrations/langfuse/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,28 @@ def test_tracing_integration(llm_class, env_var, expected_trace):
assert "Berlin" in response["llm"]["replies"][0].content
assert response["tracer"]["trace_url"]

# add a random delay between 1 and 3 seconds to make sure the trace is flushed
# and that the trace is available in Langfuse when we fetch it below
time.sleep(random.uniform(1, 3))

url = "https://cloud.langfuse.com/api/public/traces/"
trace_url = response["tracer"]["trace_url"]
uuid = os.path.basename(urlparse(trace_url).path)
url = f"https://cloud.langfuse.com/api/public/traces/{uuid}"

try:
response = requests.get(
url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"])
# Poll the Langfuse API a bit as the trace might not be ready right away
attempts = 5
delay = 1
while attempts >= 0:
res = requests.get(
url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"])
)
assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}"
if attempts > 0 and res.status_code != 200:
attempts -= 1
time.sleep(delay)
delay *= 2
continue
assert res.status_code == 200, f"Failed to retrieve data from Langfuse API: {res.status_code}"

# check if the trace contains the expected LLM name
assert expected_trace in str(response.content)
assert expected_trace in str(res.content)
# check if the trace contains the expected generation span
assert "GENERATION" in str(response.content)
assert "GENERATION" in str(res.content)
# check if the trace contains the expected user_id
assert "user_42" in str(response.content)
except requests.exceptions.RequestException as e:
pytest.fail(f"Failed to retrieve data from Langfuse API: {e}")
assert "user_42" in str(res.content)
break

0 comments on commit 1e3ac82

Please sign in to comment.