feat: add support for ttft #1161

Merged
merged 5 commits on Nov 9, 2024
integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py
@@ -1,5 +1,7 @@
 import contextlib
+import logging
 import os
+from datetime import datetime
 from typing import Any, Dict, Iterator, Optional, Union
 
 from haystack.components.generators.openai_utils import _convert_message_to_openai_format
@@ -9,6 +11,8 @@
 
 import langfuse
 
+logger = logging.getLogger(__name__)
+
 HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH"
 _SUPPORTED_GENERATORS = [
     "AzureOpenAIGenerator",
@@ -148,7 +152,18 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I
             replies = span._data.get("haystack.component.output", {}).get("replies")
             if replies:
                 meta = replies[0].meta
-                span._span.update(usage=meta.get("usage") or None, model=meta.get("model"))
+                completion_start_time = meta.get("completion_start_time")
+                if completion_start_time:
+                    try:
+                        completion_start_time = datetime.fromisoformat(completion_start_time)
+                    except ValueError:
+                        logger.error(f"Failed to parse completion_start_time: {completion_start_time}")
+                        completion_start_time = None
+                span._span.update(
+                    usage=meta.get("usage") or None,
+                    model=meta.get("model"),
+                    completion_start_time=completion_start_time,
+                )
 
         pipeline_input = tags.get("haystack.pipeline.input_data", None)
         if pipeline_input:
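For context on what the new field enables: Langfuse derives time-to-first-token (TTFT) for a generation from completion_start_time, i.e. the gap between the moment the generation starts and the moment the first token of the completion arrives. A rough sketch of that arithmetic, with illustrative timestamps rather than values taken from this diff:

from datetime import datetime

# Illustrative only: the generation starts at 10:15:30.000 and the first
# streamed token arrives 850 ms later.
generation_start = datetime.fromisoformat("2024-11-09T10:15:30.000000")
completion_start_time = datetime.fromisoformat("2024-11-09T10:15:30.850000")

# Time-to-first-token is the difference between the two timestamps.
ttft_seconds = (completion_start_time - generation_start).total_seconds()
print(f"TTFT: {ttft_seconds:.3f}s")  # TTFT: 0.850s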
integrations/langfuse/tests/test_tracer.py (69 additions, 32 deletions)
@@ -1,9 +1,43 @@
 import os
+import datetime
 from unittest.mock import MagicMock, Mock, patch
 
 from haystack.dataclasses import ChatMessage
 from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer
 
 
+class MockSpan:
+    def __init__(self):
+        self._data = {}
+        self._span = self
+        self.operation_name = "operation_name"
+
+    def raw_span(self):
+        return self
+
+    def span(self, name=None):
+        # assert correct operation name passed to the span
+        assert name == "operation_name"
+        return self
+
+    def update(self, **kwargs):
+        self._data.update(kwargs)
+
+    def generation(self, name=None):
+        return self
+
+    def end(self):
+        pass
+
+
+class MockTracer:
+
+    def trace(self, name, **kwargs):
+        return MockSpan()
+
+    def flush(self):
+        pass
+
+
 class TestLangfuseTracer:
 
     # LangfuseTracer can be initialized with a Langfuse instance, a name and a boolean value for public.
@@ -45,44 +79,47 @@ def test_create_new_span(self):
 
     # check that update method is called on the span instance with the provided key value pairs
     def test_update_span_with_pipeline_input_output_data(self):
-        class MockTracer:
-
-            def trace(self, name, **kwargs):
-                return MockSpan()
-
-            def flush(self):
-                pass
-
-        class MockSpan:
-            def __init__(self):
-                self._data = {}
-                self._span = self
-                self.operation_name = "operation_name"
-
-            def raw_span(self):
-                return self
-
-            def span(self, name=None):
-                # assert correct operation name passed to the span
-                assert name == "operation_name"
-                return self
-
-            def update(self, **kwargs):
-                self._data.update(kwargs)
-
-            def generation(self, name=None):
-                return self
-
-            def end(self):
-                pass
-
         tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False)
         with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
             assert span.raw_span()._data["metadata"] == {"haystack.pipeline.input_data": "hello"}
 
         with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.output_data": "bye"}) as span:
             assert span.raw_span()._data["metadata"] == {"haystack.pipeline.output_data": "bye"}
+
+    def test_trace_generation(self):
+        tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False)
+        tags = {
+            "haystack.component.type": "OpenAIChatGenerator",
+            "haystack.component.output": {
+                "replies": [
+                    ChatMessage.from_assistant(
+                        "", meta={"completion_start_time": "2021-07-27T16:02:08.012345", "model": "test_model"}
+                    )
+                ]
+            },
+        }
+        with tracer.trace(operation_name="operation_name", tags=tags) as span:
+            ...
+        assert span.raw_span()._data["usage"] is None
+        assert span.raw_span()._data["model"] == "test_model"
+        assert span.raw_span()._data["completion_start_time"] == datetime.datetime(2021, 7, 27, 16, 2, 8, 12345)
+
+    def test_trace_generation_invalid_start_time(self):
+        tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False)
+        tags = {
+            "haystack.component.type": "OpenAIChatGenerator",
+            "haystack.component.output": {
+                "replies": [
+                    ChatMessage.from_assistant("", meta={"completion_start_time": "foobar", "model": "test_model"}),
+                ]
+            },
+        }
+        with tracer.trace(operation_name="operation_name", tags=tags) as span:
+            ...
+        assert span.raw_span()._data["usage"] is None
+        assert span.raw_span()._data["model"] == "test_model"
+        assert span.raw_span()._data["completion_start_time"] is None
 
     def test_update_span_gets_flushed_by_default(self):
         tracer_mock = Mock()
 
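As the new tests illustrate, TTFT is only recorded when the first reply's meta carries an ISO 8601 completion_start_time string; malformed values are logged and dropped rather than raised. A minimal, hand-built reply of the shape the tracer reads follows (the model name and timestamp are made up for illustration; a streaming chat generator is expected to populate these fields itself):

from haystack.dataclasses import ChatMessage

reply = ChatMessage.from_assistant(
    "Hello!",
    meta={
        "model": "test_model",  # illustrative
        "usage": None,
        "completion_start_time": "2024-11-09T10:15:30.850000",  # ISO 8601
    },
)

# The tracer parses this timestamp with datetime.fromisoformat and forwards it
# to Langfuse via span.update(completion_start_time=...).
print(reply.meta["completion_start_time"])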