Skip to content

Commit

Permalink
Add streaming_callback in init
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Aug 21, 2024
1 parent d409b24 commit c414ab0
Showing 1 changed file with 43 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import vertexai
from haystack.core.component import component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from vertexai.preview.generative_models import (
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
`VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models.
Expand All @@ -76,6 +78,9 @@ def __init__(
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
the list of supported arguments.
:param streaming_callback: A callback function that is called when a new token is received from
the stream. The callback function accepts StreamingChunk as an argument.
"""

# Login to GCP. This will fail if user has not set up their gcloud SDK
Expand All @@ -89,6 +94,7 @@ def __init__(
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._streaming_callback = streaming_callback

def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
return {
Expand Down Expand Up @@ -129,9 +135,10 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
streaming_callback=self._streaming_callback,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
Expand Down Expand Up @@ -211,10 +218,26 @@ def run(self, messages: List[ChatMessage]):
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
stream=self._streaming_callback is not None,
)

replies = (
self.get_stream_response(res, self._streaming_callback)
if self._streaming_callback
else self.get_response(res)
)

return {"replies": replies}

def get_response(self, response_body) -> List[ChatMessage]:
"""
Extracts the responses from the Vertex AI response.
:param response_body: The response from VertexAi request.
:returns: The extracted responses.
"""
replies = []
for candidate in res.candidates:
for candidate in response_body.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(ChatMessage.from_system(part.text))
Expand All @@ -226,5 +249,21 @@ def run(self, messages: List[ChatMessage]):
name=part.function_call.name,
)
)
return replies

return {"replies": replies}
def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[ChatMessage]:
"""
Extracts the responses from the Vertex AI streaming response.
:param stream: The streaming response from the Vertex AI request.
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
for chunk in stream:
streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.usage_metadata)
streaming_callback(streaming_chunk)
responses.append(streaming_chunk.content)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_system(content=combined_response)]

0 comments on commit c414ab0

Please sign in to comment.