Skip to content

Commit

Permalink
Fix GoogleAIGeminiChatGenerator to_dict and from_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jul 9, 2024
1 parent cfc1cee commit 3f1bce7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Dict, List, Optional, Union

import google.generativeai as genai
from google.ai.generativelanguage import Content, Part, Tool
from google.ai.generativelanguage import Content, Part
from google.ai.generativelanguage import Tool as ToolProto
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
from haystack.core.component import component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.byte_stream import ByteStream
Expand Down Expand Up @@ -159,7 +160,14 @@ def to_dict(self) -> Dict[str, Any]:
tools=self._tools,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools]
data["init_parameters"]["tools"] = []
for tool in tools:
if isinstance(tool, Tool):
# There are multiple Tool types in the Google lib, one that is a protobuf class and
# another is a simple Python class. They have a similar structure but the Python class
# can't be easily serializated to a dict. We need to convert it to a protobuf class first.
tool = tool.to_proto() # noqa: PLW2901
data["init_parameters"]["tools"].append(ToolProto.serialize(tool))
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand All @@ -179,7 +187,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools]
deserialized_tools = []
for tool in tools:
# Tools are always serialized as a protobuf class, so we need to deserialize them first
# to be able to convert them to the Python class.
proto = ToolProto.deserialize(tool)
deserialized_tools.append(
Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution)
)
data["init_parameters"]["tools"] = deserialized_tools
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand Down
37 changes: 10 additions & 27 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,33 +157,16 @@ def test_from_dict(monkeypatch):
top_k=2,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [
Tool(
function_declarations=[
FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)
]
)
]
assert len(gemini._tools) == 1
assert len(gemini._tools[0].function_declarations) == 1
assert gemini._tools[0].function_declarations[0].name == "get_current_weather"
assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location"
assert (
gemini._tools[0].function_declarations[0].parameters.properties["location"].description
== "The city and state, e.g. San Francisco, CA"
)
assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"]
assert gemini._tools[0].function_declarations[0].parameters.required == ["location"]
assert isinstance(gemini._model, GenerativeModel)


Expand Down

0 comments on commit 3f1bce7

Please sign in to comment.