Skip to content

Commit

Permalink
fix: GoogleAIGeminiGenerator - remove support for tools and change …
Browse files Browse the repository at this point in the history
…output type (#1177)

* GoogleAIGeminiGenerator - rm support for tools

* simplify
  • Loading branch information
anakin87 authored Nov 12, 2024
1 parent 4cfee2d commit 946e154
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, Dict, List, Optional, Union

import google.generativeai as genai
from google.ai.generativelanguage import Content, Part, Tool
from google.ai.generativelanguage import Content, Part
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory
from haystack.core.component import component
Expand Down Expand Up @@ -62,14 +62,23 @@ class GoogleAIGeminiGenerator:
```
"""

def __new__(cls, *_, **kwargs):
if "tools" in kwargs:
msg = (
"GoogleAIGeminiGenerator does not support the `tools` parameter. "
" Use GoogleAIGeminiChatGenerator instead."
)
raise TypeError(msg)
return super(GoogleAIGeminiGenerator, cls).__new__(cls) # noqa: UP008
# super(__class__, cls) is needed because of the component decorator

def __init__(
self,
*,
api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), # noqa: B008
model: str = "gemini-1.5-flash",
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Expand All @@ -86,7 +95,6 @@ def __init__(
:param safety_settings: The safety settings to use.
A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values.
For more information, see [the API reference](https://ai.google.dev/api)
:param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling).
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
"""
Expand All @@ -96,8 +104,7 @@ def __init__(
self._model_name = model
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._model = GenerativeModel(self._model_name, tools=self._tools)
self._model = GenerativeModel(self._model_name)
self._streaming_callback = streaming_callback

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
Expand Down Expand Up @@ -126,11 +133,8 @@ def to_dict(self) -> Dict[str, Any]:
model=self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand All @@ -149,8 +153,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator":
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand Down Expand Up @@ -178,7 +180,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(replies=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[str])
def run(
self,
parts: Variadic[Union[str, ByteStream, Part]],
Expand All @@ -192,7 +194,7 @@ def run(
:param streaming_callback: A callback function that is called when a new token is received from the stream.
:returns:
A dictionary containing the following key:
- `replies`: A list of strings or dictionaries with function calls.
- `replies`: A list of strings containing the generated responses.
"""

# check if streaming_callback is passed
Expand Down Expand Up @@ -221,12 +223,6 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[str]:
for part in candidate.content.parts:
if part.text != "":
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
replies.append(function_call)
return replies

def _get_stream_response(
Expand Down
85 changes: 5 additions & 80 deletions integrations/google_ai/tests/generators/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,12 @@
from unittest.mock import patch

import pytest
from google.ai.generativelanguage import FunctionDeclaration, Tool
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from haystack.dataclasses import StreamingChunk

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator

GET_CURRENT_WEATHER_FUNC = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)


def test_init(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "test")
Expand All @@ -41,40 +21,24 @@ def test_init(monkeypatch):
top_k=0.5,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])
with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure:
gemini = GoogleAIGeminiGenerator(
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
)
mock_genai_configure.assert_called_once_with(api_key="test")
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert isinstance(gemini._model, GenerativeModel)


def test_init_fails_with_tools():
with pytest.raises(TypeError, match="GoogleAIGeminiGenerator does not support the `tools` parameter."):
GoogleAIGeminiGenerator(tools=["tool1", "tool2"])


def test_to_dict(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "test")

Expand All @@ -88,7 +52,6 @@ def test_to_dict(monkeypatch):
"generation_config": None,
"safety_settings": None,
"streaming_callback": None,
"tools": None,
},
}

Expand All @@ -105,32 +68,11 @@ def test_to_dict_with_param(monkeypatch):
top_k=2,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])

with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"):
gemini = GoogleAIGeminiGenerator(
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
)
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
Expand All @@ -147,11 +89,6 @@ def test_to_dict_with_param(monkeypatch):
},
"safety_settings": {10: 3},
"streaming_callback": None,
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai"
b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08"
b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}

Expand All @@ -175,11 +112,6 @@ def test_from_dict_with_param(monkeypatch):
},
"safety_settings": {10: 3},
"streaming_callback": None,
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai"
b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08"
b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}
)
Expand All @@ -194,7 +126,6 @@ def test_from_dict_with_param(monkeypatch):
top_k=0.5,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]
assert isinstance(gemini._model, GenerativeModel)


Expand All @@ -217,11 +148,6 @@ def test_from_dict(monkeypatch):
},
"safety_settings": {10: 3},
"streaming_callback": None,
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai"
b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08"
b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}
)
Expand All @@ -236,7 +162,6 @@ def test_from_dict(monkeypatch):
top_k=0.5,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]
assert isinstance(gemini._model, GenerativeModel)


Expand Down

0 comments on commit 946e154

Please sign in to comment.