From 946e1540386d9d03ca670e660f63c6572e1d1b5b Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 12 Nov 2024 12:55:01 +0100 Subject: [PATCH] fix: `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) * GoogleAIGeminiGenerator - rm support for tools * simplify --- .../components/generators/google_ai/gemini.py | 32 +++---- .../google_ai/tests/generators/test_gemini.py | 85 ++----------------- 2 files changed, 19 insertions(+), 98 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 218e16c4c..b032169df 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory from haystack.core.component import component @@ -62,6 +62,16 @@ class GoogleAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs: + msg = ( + "GoogleAIGeminiGenerator does not support the `tools` parameter. " + " Use GoogleAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(GoogleAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -69,7 +79,6 @@ def __init__( model: str = "gemini-1.5-flash", generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -86,7 +95,6 @@ def __init__( :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) - :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. """ @@ -96,8 +104,7 @@ def __init__( self._model_name = model self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._model = GenerativeModel(self._model_name, tools=self._tools) + self._model = GenerativeModel(self._model_name) self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: @@ -126,11 +133,8 @@ def to_dict(self) -> Dict[str, Any]: model=self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -149,8 +153,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -178,7 +180,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -192,7 +194,7 @@ def run( :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - - `replies`: A list of strings or dictionaries with function calls. + - `replies`: A list of strings containing the generated responses. """ # check if streaming_callback is passed @@ -221,12 +223,6 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[str]: for part in candidate.content.parts: if part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 7206b7a43..07d194a59 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -2,32 +2,12 @@ from unittest.mock import patch import pytest -from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -41,40 +21,24 @@ def test_init(monkeypatch): top_k=0.5, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure: gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) mock_genai_configure.assert_called_once_with(api_key="test") assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] assert isinstance(gemini._model, GenerativeModel) +def test_init_fails_with_tools(): + with pytest.raises(TypeError, match="GoogleAIGeminiGenerator does not support the `tools` parameter."): + GoogleAIGeminiGenerator(tools=["tool1", "tool2"]) + + def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -88,7 +52,6 @@ def test_to_dict(monkeypatch): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, }, } @@ -105,32 +68,11 @@ def test_to_dict_with_param(monkeypatch): top_k=2, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", @@ -147,11 +89,6 @@ def test_to_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } @@ -175,11 +112,6 @@ def test_from_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } ) @@ -194,7 +126,6 @@ def test_from_dict_with_param(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) @@ -217,11 +148,6 @@ def test_from_dict(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } ) @@ -236,7 +162,6 @@ def test_from_dict(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel)