feat: enable streaming for VertexAIGeminiChatGenerator (#1014)
* Add streaming_callback in init() and run()
Amnah199 authored Aug 27, 2024
1 parent ee08a47 commit e2d566a
Showing 1 changed file with 58 additions and 21 deletions.
@@ -1,15 +1,17 @@
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import vertexai
from haystack.core.component import component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.utils import deserialize_callable, serialize_callable
from vertexai import init as vertexai_init
from vertexai.preview.generative_models import (
Content,
FunctionDeclaration,
GenerationConfig,
GenerationResponse,
GenerativeModel,
HarmBlockThreshold,
HarmCategory,
@@ -55,6 +57,7 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
`VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models.
@@ -76,10 +79,13 @@
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
for the list of supported arguments.
:param streaming_callback: A callback function that is called when a new token is received from
the stream. The callback accepts a single StreamingChunk as its argument.
"""

# Log in to GCP. This will fail if the user has not set up the gcloud SDK
vertexai.init(project=project_id, location=location)
vertexai_init(project=project_id, location=location)

self._model_name = model
self._project_id = project_id
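
For context, this is how the new parameter is meant to be used. A minimal sketch; the import path, the placeholder project ID, and the callback name are assumptions for illustration, not part of this diff:

```python
from haystack.dataclasses import StreamingChunk

# Assumed import path for the component (not shown in this diff).
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator


def print_chunk(chunk: StreamingChunk) -> None:
    # Emit each streamed token as soon as it arrives.
    print(chunk.content, flush=True, end="")


# A callback set at init time applies to every subsequent run() call.
generator = VertexAIGeminiChatGenerator(
    project_id="my-gcp-project",  # placeholder; requires a configured gcloud SDK
    streaming_callback=print_chunk,
)
```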
@@ -89,18 +95,7 @@
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools

def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
return {
"name": function._raw_function_declaration.name,
"parameters": function._raw_function_declaration.parameters,
"description": function._raw_function_declaration.description,
}

def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
return {
"function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
}
self._streaming_callback = streaming_callback

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
@@ -121,6 +116,8 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None

data = default_to_dict(
self,
model=self._model_name,
@@ -129,9 +126,10 @@
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
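
Since a Python callable cannot be JSON-serialized directly, `serialize_callable` stores the callback as its dotted import path. A small sketch of that convention (the module and function names here are hypothetical):

```python
from haystack.utils import deserialize_callable, serialize_callable


def my_callback(chunk):
    print(chunk.content, end="")


# The callable is stored as "module.qualified_name", e.g. "my_module.my_callback".
path = serialize_callable(my_callback)

# from_dict later resolves the path back to the function object.
restored = deserialize_callable(path)
assert restored is my_callback  # holds for module-level functions
```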
@@ -150,7 +148,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator":
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)

if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
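
Putting the two halves together, a round trip might look like the sketch below. It assumes the hypothetical `generator` from the earlier example and a working GCP setup, since `from_dict` re-runs `__init__` and therefore re-initializes Vertex AI:

```python
data = generator.to_dict()
# The callback is now a plain string, so the dict is JSON/YAML-safe.
assert isinstance(data["init_parameters"]["streaming_callback"], str)

restored = VertexAIGeminiChatGenerator.from_dict(data)  # re-initializes Vertex AI
```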

def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
@@ -195,13 +194,21 @@ def _message_to_content(self, message: ChatMessage) -> Content:
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage]):
def run(
self,
messages: List[ChatMessage],
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""Prompts Google Vertex AI Gemini model to generate a response to a list of messages.
:param messages: The last message is the prompt, the rest are the history.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
:returns: A dictionary with the following keys:
- `replies`: A list of ChatMessage objects representing the model's replies.
"""
# the callback passed to run() takes precedence over the one set in __init__
streaming_callback = streaming_callback or self._streaming_callback

history = [self._message_to_content(m) for m in messages[:-1]]
session = self._model.start_chat(history=history)

@@ -211,10 +218,22 @@ def run(self, messages: List[ChatMessage]):
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
stream=streaming_callback is not None,
)

replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res)

return {"replies": replies}
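
A callback passed to `run()` overrides the one given at `__init__` for that call only, and its presence is what flips `send_message(..., stream=True)`. A usage sketch, reusing the hypothetical `generator` and `print_chunk` from above:

```python
from haystack.dataclasses.chat_message import ChatMessage

messages = [ChatMessage.from_user("Write a haiku about streaming APIs.")]

# Streams through print_chunk; omit streaming_callback (and set none at init)
# to get a single blocking response instead.
result = generator.run(messages, streaming_callback=print_chunk)
print(result["replies"][0].content)
```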

def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
"""
Extracts the responses from the Vertex AI response.
:param response_body: The response from the Vertex AI request.
:returns: The extracted responses.
"""
replies = []
for candidate in res.candidates:
for candidate in response_body.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(ChatMessage.from_system(part.text))
@@ -226,5 +245,23 @@ def run(self, messages: List[ChatMessage]):
name=part.function_call.name,
)
)
return replies

return {"replies": replies}
def _get_stream_response(
self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None]
) -> List[ChatMessage]:
"""
Extracts the responses from the Vertex AI streaming response.
:param stream: The streaming response from the Vertex AI request.
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
for chunk in stream:
streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
streaming_callback(streaming_chunk)
responses.append(streaming_chunk.content)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_system(content=combined_response)]
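
As `_get_stream_response` shows, each chunk carries the text delta in `StreamingChunk.content` and the raw payload in `.meta`, and the final reply is just the concatenation of the deltas. A callback is free to accumulate rather than print; a sketch using the hypothetical objects from the earlier examples:

```python
collected = []


def collect(chunk: StreamingChunk) -> None:
    # Keep every chunk for later inspection (e.g. token accounting via .meta).
    collected.append(chunk)


reply = generator.run(messages, streaming_callback=collect)["replies"][0]
# The reply is the concatenated chunk contents, left-stripped as in the code above.
assert reply.content == "".join(c.content for c in collected).lstrip()
```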
