Skip to content

Commit

Permalink
Add stream param to run
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Aug 21, 2024
1 parent c414ab0 commit 00b2f31
Showing 1 changed file with 9 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from vertexai.preview.generative_models import (
Content,
FunctionDeclaration,
GenerationConfig,
GenerativeModel,
HarmBlockThreshold,
Expand Down Expand Up @@ -96,18 +95,6 @@ def __init__(
self._tools = tools
self._streaming_callback = streaming_callback

def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
return {
"name": function._raw_function_declaration.name,
"parameters": function._raw_function_declaration.parameters,
"description": function._raw_function_declaration.description,
}

def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
return {
"function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
}

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
Expand Down Expand Up @@ -202,13 +189,21 @@ def _message_to_content(self, message: ChatMessage) -> Content:
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage]):
def run(
self,
messages: List[ChatMessage],
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""Prompts Google Vertex AI Gemini model to generate a response to a list of messages.
:param messages: The last message is the prompt, the rest are the history.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
:returns: A dictionary with the following keys:
- `replies`: A list of ChatMessage objects representing the model's replies.
"""
# check if streaming_callback is passed
streaming_callback = streaming_callback or self._streaming_callback

history = [self._message_to_content(m) for m in messages[:-1]]
session = self._model.start_chat(history=history)

Expand Down

0 comments on commit 00b2f31

Please sign in to comment.