From 3f1bce7c76ed50e5569f7e81344f37443e437f4a Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Tue, 9 Jul 2024 10:55:07 +0200 Subject: [PATCH] Fix GoogleAIGeminiChatGenerator to_dict and from_dict --- .../generators/google_ai/chat/gemini.py | 24 ++++++++++-- .../tests/generators/chat/test_chat_gemini.py | 37 +++++-------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index d3a8299fd..8b592a184 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -2,9 +2,10 @@ from typing import Any, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part +from google.ai.generativelanguage import Tool as ToolProto from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.byte_stream import ByteStream @@ -159,7 +160,14 @@ def to_dict(self) -> Dict[str, Any]: tools=self._tools, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] + data["init_parameters"]["tools"] = [] + for tool in tools: + if isinstance(tool, Tool): + # There are multiple Tool types in the Google lib, one that is a protobuf class and + # another is a simple Python class. They have a similar structure but the Python class + # can't be easily serializated to a dict. We need to convert it to a protobuf class first. + tool = tool.to_proto() # noqa: PLW2901 + data["init_parameters"]["tools"].append(ToolProto.serialize(tool)) if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -179,7 +187,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] + deserialized_tools = [] + for tool in tools: + # Tools are always serialized as a protobuf class, so we need to deserialize them first + # to be able to convert them to the Python class. + proto = ToolProto.deserialize(tool) + deserialized_tools.append( + Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution) + ) + data["init_parameters"]["tools"] = deserialized_tools if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 03bcf62c5..9b3124eab 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -157,33 +157,16 @@ def test_from_dict(monkeypatch): top_k=2, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [ - Tool( - function_declarations=[ - FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - ] - ) - ] + assert len(gemini._tools) == 1 + assert len(gemini._tools[0].function_declarations) == 1 + assert gemini._tools[0].function_declarations[0].name == "get_current_weather" + assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location" + assert ( + gemini._tools[0].function_declarations[0].parameters.properties["location"].description + == "The city and state, e.g. San Francisco, CA" + ) + assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"] + assert gemini._tools[0].function_declarations[0].parameters.required == ["location"] assert isinstance(gemini._model, GenerativeModel)