Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Google AI tests failing #885

Merged
merged 2 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Dict, List, Optional, Union

import google.generativeai as genai
from google.ai.generativelanguage import Content, Part, Tool
from google.ai.generativelanguage import Content, Part
from google.ai.generativelanguage import Tool as ToolProto
silvanocerza marked this conversation as resolved.
Show resolved Hide resolved
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
from haystack.core.component import component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.byte_stream import ByteStream
Expand Down Expand Up @@ -159,7 +160,14 @@ def to_dict(self) -> Dict[str, Any]:
tools=self._tools,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools]
data["init_parameters"]["tools"] = []
for tool in tools:
if isinstance(tool, Tool):
# There are multiple Tool types in the Google lib, one that is a protobuf class and
# another is a simple Python class. They have a similar structure but the Python class
# can't be easily serializated to a dict. We need to convert it to a protobuf class first.
tool = tool.to_proto() # noqa: PLW2901
data["init_parameters"]["tools"].append(ToolProto.serialize(tool))
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand All @@ -179,7 +187,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools]
deserialized_tools = []
silvanocerza marked this conversation as resolved.
Show resolved Hide resolved
for tool in tools:
# Tools are always serialized as a protobuf class, so we need to deserialize them first
# to be able to convert them to the Python class.
proto = ToolProto.deserialize(tool)
deserialized_tools.append(
Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution)
)
data["init_parameters"]["tools"] = deserialized_tools
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand Down
61 changes: 16 additions & 45 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from unittest.mock import patch

import pytest
from google.ai.generativelanguage import FunctionDeclaration, Tool
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
from haystack.dataclasses.chat_message import ChatMessage

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
Expand Down Expand Up @@ -158,33 +157,16 @@ def test_from_dict(monkeypatch):
top_k=2,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [
Tool(
function_declarations=[
FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)
]
)
]
assert len(gemini._tools) == 1
assert len(gemini._tools[0].function_declarations) == 1
assert gemini._tools[0].function_declarations[0].name == "get_current_weather"
assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location"
assert (
gemini._tools[0].function_declarations[0].parameters.properties["location"].description
== "The city and state, e.g. San Francisco, CA"
)
assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"]
assert gemini._tools[0].function_declarations[0].parameters.required == ["location"]
assert isinstance(gemini._model, GenerativeModel)


Expand All @@ -195,22 +177,11 @@ def test_run():
def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
return {"weather": "sunny", "temperature": 21.8, "unit": unit}

get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

Expand Down