-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b2d5269
commit 6f640ad
Showing
2 changed files
with
391 additions
and
0 deletions.
There are no files selected for viewing
211 changes: 211 additions & 0 deletions
211
integrations/gemini-haystack/tests/generators/chat/test_chat_gemini.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
import os | ||
from unittest.mock import patch | ||
|
||
from haystack.dataclasses.chat_message import ChatMessage | ||
from google.generativeai import GenerationConfig, GenerativeModel | ||
from google.generativeai.types import HarmBlockThreshold, HarmCategory | ||
from google.ai.generativelanguage import FunctionDeclaration, Tool | ||
import pytest | ||
|
||
from gemini_haystack.generators.chat.gemini import GoogleAIGeminiChatGenerator | ||
|
||
|
||
def test_init(): | ||
generation_config = GenerationConfig( | ||
candidate_count=1, | ||
stop_sequences=["stop"], | ||
max_output_tokens=10, | ||
temperature=0.5, | ||
top_p=0.5, | ||
top_k=0.5, | ||
) | ||
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} | ||
get_current_weather_func = FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
|
||
tool = Tool(function_declarations=[get_current_weather_func]) | ||
with patch("gemini_haystack.generators.chat.gemini.genai.configure") as mock_genai_configure: | ||
gemini = GoogleAIGeminiChatGenerator( | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
tools=[tool], | ||
) | ||
mock_genai_configure.assert_called_once_with(api_key=None) | ||
assert gemini._model_name == "gemini-pro-vision" | ||
assert gemini._generation_config == generation_config | ||
assert gemini._safety_settings == safety_settings | ||
assert gemini._tools == [tool] | ||
assert isinstance(gemini._model, GenerativeModel) | ||
|
||
|
||
def test_to_dict(): | ||
generation_config = GenerationConfig( | ||
candidate_count=1, | ||
stop_sequences=["stop"], | ||
max_output_tokens=10, | ||
temperature=0.5, | ||
top_p=0.5, | ||
top_k=0.5, | ||
) | ||
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} | ||
get_current_weather_func = FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
|
||
tool = Tool(function_declarations=[get_current_weather_func]) | ||
|
||
with patch("gemini_haystack.generators.chat.gemini.genai.configure"): | ||
gemini = GoogleAIGeminiChatGenerator( | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
tools=[tool], | ||
) | ||
assert gemini.to_dict() == { | ||
"type": "gemini_haystack.generators.chat.gemini.GoogleAIGeminiChatGenerator", | ||
"init_parameters": { | ||
"model": "gemini-pro-vision", | ||
"generation_config": { | ||
"temperature": 0.5, | ||
"top_p": 0.5, | ||
"top_k": 0.5, | ||
"candidate_count": 1, | ||
"max_output_tokens": 10, | ||
"stop_sequences": ["stop"], | ||
}, | ||
"safety_settings": {6: 3}, | ||
"tools": [ | ||
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" | ||
], | ||
}, | ||
} | ||
|
||
|
||
def test_from_dict(): | ||
with patch("gemini_haystack.generators.chat.gemini.genai.configure"): | ||
gemini = GoogleAIGeminiChatGenerator.from_dict( | ||
{ | ||
"type": "gemini_haystack.generators.chat.gemini.GoogleAIGeminiChatGenerator", | ||
"init_parameters": { | ||
"model": "gemini-pro-vision", | ||
"generation_config": { | ||
"temperature": 0.5, | ||
"top_p": 0.5, | ||
"top_k": 0.5, | ||
"candidate_count": 1, | ||
"max_output_tokens": 10, | ||
"stop_sequences": ["stop"], | ||
}, | ||
"safety_settings": {6: 3}, | ||
"tools": [ | ||
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" | ||
], | ||
}, | ||
} | ||
) | ||
|
||
assert gemini._model_name == "gemini-pro-vision" | ||
assert gemini._generation_config == GenerationConfig( | ||
candidate_count=1, | ||
stop_sequences=["stop"], | ||
max_output_tokens=10, | ||
temperature=0.5, | ||
top_p=0.5, | ||
top_k=0.5, | ||
) | ||
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} | ||
assert gemini._tools == [ | ||
Tool( | ||
function_declarations=[ | ||
FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": { | ||
"type_": "STRING", | ||
"description": "The city and state, e.g. San Francisco, CA", | ||
}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
] | ||
) | ||
] | ||
assert isinstance(gemini._model, GenerativeModel) | ||
|
||
|
||
@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY not set") | ||
def test_run(): | ||
def get_current_weather(location: str, unit: str = "celsius"): | ||
return {"weather": "sunny", "temperature": 21.8, "unit": unit} | ||
|
||
get_current_weather_func = FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
|
||
tool = Tool(function_declarations=[get_current_weather_func]) | ||
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) | ||
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] | ||
res = gemini_chat.run(messages=messages) | ||
assert len(res["replies"]) > 0 | ||
|
||
weather = get_current_weather(**res["replies"][0].content) | ||
messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] | ||
|
||
res = gemini_chat.run(messages=messages) | ||
assert len(res["replies"]) > 0 |
180 changes: 180 additions & 0 deletions
180
integrations/gemini-haystack/tests/generators/test_gemini.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import os | ||
from unittest.mock import patch | ||
|
||
from google.generativeai import GenerationConfig, GenerativeModel | ||
from google.generativeai.types import HarmBlockThreshold, HarmCategory | ||
from google.ai.generativelanguage import FunctionDeclaration, Tool | ||
import pytest | ||
|
||
from gemini_haystack.generators.gemini import GoogleAIGeminiGenerator | ||
|
||
|
||
def test_init(): | ||
generation_config = GenerationConfig( | ||
candidate_count=1, | ||
stop_sequences=["stop"], | ||
max_output_tokens=10, | ||
temperature=0.5, | ||
top_p=0.5, | ||
top_k=0.5, | ||
) | ||
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} | ||
get_current_weather_func = FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
|
||
tool = Tool(function_declarations=[get_current_weather_func]) | ||
with patch("gemini_haystack.generators.gemini.genai.configure") as mock_genai_configure: | ||
gemini = GoogleAIGeminiGenerator( | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
tools=[tool], | ||
) | ||
mock_genai_configure.assert_called_once_with(api_key=None) | ||
assert gemini._model_name == "gemini-pro-vision" | ||
assert gemini._generation_config == generation_config | ||
assert gemini._safety_settings == safety_settings | ||
assert gemini._tools == [tool] | ||
assert isinstance(gemini._model, GenerativeModel) | ||
|
||
|
||
def test_to_dict(): | ||
generation_config = GenerationConfig( | ||
candidate_count=1, | ||
stop_sequences=["stop"], | ||
max_output_tokens=10, | ||
temperature=0.5, | ||
top_p=0.5, | ||
top_k=0.5, | ||
) | ||
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} | ||
get_current_weather_func = FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
|
||
tool = Tool(function_declarations=[get_current_weather_func]) | ||
|
||
with patch("gemini_haystack.generators.gemini.genai.configure"): | ||
gemini = GoogleAIGeminiGenerator( | ||
generation_config=generation_config, | ||
safety_settings=safety_settings, | ||
tools=[tool], | ||
) | ||
assert gemini.to_dict() == { | ||
"type": "gemini_haystack.generators.gemini.GoogleAIGeminiGenerator", | ||
"init_parameters": { | ||
"model": "gemini-pro-vision", | ||
"generation_config": { | ||
"temperature": 0.5, | ||
"top_p": 0.5, | ||
"top_k": 0.5, | ||
"candidate_count": 1, | ||
"max_output_tokens": 10, | ||
"stop_sequences": ["stop"], | ||
}, | ||
"safety_settings": {6: 3}, | ||
"tools": [ | ||
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" | ||
], | ||
}, | ||
} | ||
|
||
|
||
def test_from_dict(): | ||
with patch("gemini_haystack.generators.gemini.genai.configure"): | ||
gemini = GoogleAIGeminiGenerator.from_dict( | ||
{ | ||
"type": "gemini_haystack.generators.gemini.GoogleAIGeminiGenerator", | ||
"init_parameters": { | ||
"model": "gemini-pro-vision", | ||
"generation_config": { | ||
"temperature": 0.5, | ||
"top_p": 0.5, | ||
"top_k": 0.5, | ||
"candidate_count": 1, | ||
"max_output_tokens": 10, | ||
"stop_sequences": ["stop"], | ||
}, | ||
"safety_settings": {6: 3}, | ||
"tools": [ | ||
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" | ||
], | ||
}, | ||
} | ||
) | ||
|
||
assert gemini._model_name == "gemini-pro-vision" | ||
assert gemini._generation_config == GenerationConfig( | ||
candidate_count=1, | ||
stop_sequences=["stop"], | ||
max_output_tokens=10, | ||
temperature=0.5, | ||
top_p=0.5, | ||
top_k=0.5, | ||
) | ||
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} | ||
assert gemini._tools == [ | ||
Tool( | ||
function_declarations=[ | ||
FunctionDeclaration( | ||
name="get_current_weather", | ||
description="Get the current weather in a given location", | ||
parameters={ | ||
"type_": "OBJECT", | ||
"properties": { | ||
"location": { | ||
"type_": "STRING", | ||
"description": "The city and state, e.g. San Francisco, CA", | ||
}, | ||
"unit": { | ||
"type_": "STRING", | ||
"enum": [ | ||
"celsius", | ||
"fahrenheit", | ||
], | ||
}, | ||
}, | ||
"required": ["location"], | ||
}, | ||
) | ||
] | ||
) | ||
] | ||
assert isinstance(gemini._model, GenerativeModel) | ||
|
||
|
||
@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY not set") | ||
def test_run(): | ||
gemini = GoogleAIGeminiGenerator(model="gemini-pro") | ||
res = gemini.run("Tell me something cool") | ||
assert len(res["answers"]) > 0 |