diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 8f54b7ad36..2de4e5b1ba 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -5,6 +5,7 @@ import copy import json import os +from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union from openai import OpenAI, Stream @@ -222,11 +223,15 @@ def run( raise ValueError("Cannot stream multiple responses, please set n=1.") chunks: List[StreamingChunk] = [] chunk = None + _first_token = True # pylint: disable=not-an-iterable for chunk in chat_completion: if chunk.choices and streaming_callback: chunk_delta: StreamingChunk = self._build_chunk(chunk) + if _first_token: + _first_token = False + chunk_delta.meta["completion_start_time"] = datetime.now().isoformat() chunks.append(chunk_delta) streaming_callback(chunk_delta) # invoke callback with the chunk_delta completions = [self._connect_chunks(chunk, chunks)] @@ -280,7 +285,12 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessa payload["function"]["arguments"] += delta.arguments or "" complete_response = ChatMessage.from_assistant(json.dumps(payloads)) else: - complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks])) + total_content = "" + total_meta = {} + for streaming_chunk in chunks: + total_content += streaming_chunk.content + total_meta.update(streaming_chunk.meta) + complete_response = ChatMessage.from_assistant(total_content, meta=total_meta) complete_response.meta.update( { "model": chunk.model, diff --git a/releasenotes/notes/openai-ttft-42b1ad551b542930.yaml b/releasenotes/notes/openai-ttft-42b1ad551b542930.yaml new file mode 100644 index 0000000000..bc39b96f7f --- /dev/null +++ b/releasenotes/notes/openai-ttft-42b1ad551b542930.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Add TTFT (Time-to-First-Token) support for OpenAI generators. This + captures the time taken to generate the first token from the model and + can be used to analyze the latency of the application. diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index c3332aae97..c74fd73dff 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import os +from unittest.mock import patch import pytest from openai import OpenAIError @@ -219,7 +220,8 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk - def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk): + @patch("haystack.components.generators.chat.openai.datetime") + def test_run_with_streaming_callback_in_run_method(self, mock_datetime, mock_chat_completion_chunk): streaming_callback_called = False def streaming_callback(chunk: StreamingChunk) -> None: @@ -240,6 +242,12 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert isinstance(response["meta"][0], dict) + assert response["meta"][0]["completion_start_time"] == mock_datetime.now.return_value.isoformat.return_value + def test_check_abnormal_completions(self, caplog): caplog.set_level(logging.INFO) component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index 047299b199..56e3733caa 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -4,6 +4,7 @@ import logging import os from typing import List +from unittest.mock import patch import pytest from openai import OpenAIError