feat: add tests for VertexAIGeminiGenerator and enable streaming #1012

Merged · 20 commits · Aug 27, 2024
@@ -1,15 +1,16 @@
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

-import vertexai
 from haystack.core.component import component
 from haystack.core.component.types import Variadic
 from haystack.core.serialization import default_from_dict, default_to_dict
-from haystack.dataclasses.byte_stream import ByteStream
-from vertexai.preview.generative_models import (
+from haystack.dataclasses import ByteStream, StreamingChunk
+from haystack.utils import deserialize_callable, serialize_callable
+from vertexai import init as vertexai_init
+from vertexai.generative_models import (
     Content,
     FunctionDeclaration,
     GenerationConfig,
+    GenerationResponse,
     GenerativeModel,
     HarmBlockThreshold,
     HarmCategory,
@@ -60,6 +61,7 @@ def __init__(
         generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
         safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
         tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         Multi-modal generator using Gemini model via Google Vertex AI.
@@ -87,10 +89,12 @@ def __init__(
         :param tools: List of tools to use when generating content. See the documentation for
             [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
             the list of supported arguments.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+            The callback function accepts StreamingChunk as an argument.
         """

         # Login to GCP. This will fail if user has not set up their gcloud SDK
-        vertexai.init(project=project_id, location=location)
+        vertexai_init(project=project_id, location=location)

         self._model_name = model
         self._project_id = project_id
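
The `streaming_callback` wired in above is plain Python: any callable that accepts a single `StreamingChunk`. A minimal sketch of such a callback (the function name and print format here are illustrative, not part of this PR):

from haystack.dataclasses import StreamingChunk

def print_streaming_chunk(chunk: StreamingChunk) -> None:
    # Emit each streamed token as soon as it arrives, without buffering.
    print(chunk.content, end="", flush=True)
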
@@ -100,18 +104,7 @@ def __init__(
         self._generation_config = generation_config
         self._safety_settings = safety_settings
         self._tools = tools
-
-    def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
-        return {
-            "name": function._raw_function_declaration.name,
-            "parameters": function._raw_function_declaration.parameters,
-            "description": function._raw_function_declaration.description,
-        }
-
-    def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
-        return {
-            "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
-        }
+        self._streaming_callback = streaming_callback

     def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(config, dict):
@@ -132,6 +125,8 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
+
+        callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
         data = default_to_dict(
             self,
             model=self._model_name,
@@ -140,9 +135,10 @@
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
             tools=self._tools,
+            streaming_callback=callback_name,
         )
         if (tools := data["init_parameters"].get("tools")) is not None:
-            data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
+            data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
         return data
@@ -161,7 +157,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator":
             data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
-
+        if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)

     def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
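
Because `serialize_callable` stores the callback as an importable dotted path and `deserialize_callable` re-imports it, the callback now survives a `to_dict`/`from_dict` round trip. A minimal sketch of that round trip, reusing the `print_streaming_chunk` helper sketched earlier (the project ID is a placeholder, and the assert assumes the callback lives in an importable module):

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

generator = VertexAIGeminiGenerator(
    project_id="my-gcp-project",  # placeholder; assumes gcloud auth is configured
    streaming_callback=print_streaming_chunk,
)
data = generator.to_dict()
# "streaming_callback" is now a dotted-path string such as
# "my_module.print_streaming_chunk", so the dict stays JSON-serializable.
restored = VertexAIGeminiGenerator.from_dict(data)
assert restored._streaming_callback is print_streaming_chunk
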
@@ -176,14 +173,21 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
             raise ValueError(msg)

     @component.output_types(replies=List[Union[str, Dict[str, str]]])
-    def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
+    def run(
+        self,
+        parts: Variadic[Union[str, ByteStream, Part]],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
         """
         Generates content using the Gemini model.

         :param parts: Prompt for the model.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
         :returns: A dictionary with the following keys:
             - `replies`: A list of generated content.
         """
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self._streaming_callback
         converted_parts = [self._convert_part(p) for p in parts]

         contents = [Content(parts=converted_parts, role="user")]
@@ -192,10 +196,23 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
             tools=self._tools,
+            stream=streaming_callback is not None,
         )
-        self._model.start_chat()
+        replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res)

+        return {"replies": replies}
+
+    def _get_response(self, response_body: GenerationResponse) -> List[str]:
+        """
+        Extracts the responses from the Vertex AI response.
+
+        :param response_body: The response body from the Vertex AI request.
+
+        :returns: A list of string responses.
+        """
         replies = []
-        for candidate in res.candidates:
+        for candidate in response_body.candidates:
             for part in candidate.content.parts:
                 if part._raw_part.text != "":
                     replies.append(part.text)
@@ -205,5 +222,24 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
                         "args": dict(part.function_call.args.items()),
                     }
                     replies.append(function_call)
+        return replies

-        return {"replies": replies}
+    def _get_stream_response(
+        self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None]
+    ) -> List[str]:
+        """
+        Extracts the responses from the Vertex AI streaming response.
+
+        :param stream: The streaming response from the Vertex AI request.
+        :param streaming_callback: The handler for the streaming response.
+        :returns: A list of string responses.
+        """
+        streaming_chunks: List[StreamingChunk] = []
+
+        for chunk in stream:
+            streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
+            streaming_chunks.append(streaming_chunk)
+            streaming_callback(streaming_chunk)
+
+        responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()]
+        return responses
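
Taken together, callers can stream either by setting the callback in `__init__` or per `run()` call; the per-call value takes precedence, and providing one switches `generate_content` into streaming mode via `stream=True`. A minimal usage sketch, reusing the placeholder project and callback from the earlier sketches (this performs a real Vertex AI call):

generator = VertexAIGeminiGenerator(
    project_id="my-gcp-project",  # placeholder
    streaming_callback=print_streaming_chunk,
)
result = generator.run(parts=["Tell me the name of a movie"])
# Each chunk was printed by the callback as it arrived; the full text is
# also returned as a single joined string in a one-element list.
print(result["replies"])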