Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jan 2, 2024
1 parent b2d5269 commit 6f640ad
Show file tree
Hide file tree
Showing 2 changed files with 391 additions and 0 deletions.
211 changes: 211 additions & 0 deletions integrations/gemini-haystack/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import os
from unittest.mock import patch

from haystack.dataclasses.chat_message import ChatMessage
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.ai.generativelanguage import FunctionDeclaration, Tool
import pytest

from gemini_haystack.generators.chat.gemini import GoogleAIGeminiChatGenerator


def test_init():
generation_config = GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH}
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])
with patch("gemini_haystack.generators.chat.gemini.genai.configure") as mock_genai_configure:
gemini = GoogleAIGeminiChatGenerator(
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
)
mock_genai_configure.assert_called_once_with(api_key=None)
assert gemini._model_name == "gemini-pro-vision"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert isinstance(gemini._model, GenerativeModel)


def test_to_dict():
generation_config = GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH}
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])

with patch("gemini_haystack.generators.chat.gemini.genai.configure"):
gemini = GoogleAIGeminiChatGenerator(
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
)
assert gemini.to_dict() == {
"type": "gemini_haystack.generators.chat.gemini.GoogleAIGeminiChatGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 0.5,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
},
"safety_settings": {6: 3},
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}


def test_from_dict():
with patch("gemini_haystack.generators.chat.gemini.genai.configure"):
gemini = GoogleAIGeminiChatGenerator.from_dict(
{
"type": "gemini_haystack.generators.chat.gemini.GoogleAIGeminiChatGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 0.5,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
},
"safety_settings": {6: 3},
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}
)

assert gemini._model_name == "gemini-pro-vision"
assert gemini._generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [
Tool(
function_declarations=[
FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)
]
)
]
assert isinstance(gemini._model, GenerativeModel)


@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY not set")
def test_run():
def get_current_weather(location: str, unit: str = "celsius"):
return {"weather": "sunny", "temperature": 21.8, "unit": unit}

get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0

weather = get_current_weather(**res["replies"][0].content)
messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]

res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
180 changes: 180 additions & 0 deletions integrations/gemini-haystack/tests/generators/test_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import os
from unittest.mock import patch

from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.ai.generativelanguage import FunctionDeclaration, Tool
import pytest

from gemini_haystack.generators.gemini import GoogleAIGeminiGenerator


def test_init():
generation_config = GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH}
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])
with patch("gemini_haystack.generators.gemini.genai.configure") as mock_genai_configure:
gemini = GoogleAIGeminiGenerator(
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
)
mock_genai_configure.assert_called_once_with(api_key=None)
assert gemini._model_name == "gemini-pro-vision"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert isinstance(gemini._model, GenerativeModel)


def test_to_dict():
generation_config = GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH}
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

tool = Tool(function_declarations=[get_current_weather_func])

with patch("gemini_haystack.generators.gemini.genai.configure"):
gemini = GoogleAIGeminiGenerator(
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
)
assert gemini.to_dict() == {
"type": "gemini_haystack.generators.gemini.GoogleAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 0.5,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
},
"safety_settings": {6: 3},
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}


def test_from_dict():
with patch("gemini_haystack.generators.gemini.genai.configure"):
gemini = GoogleAIGeminiGenerator.from_dict(
{
"type": "gemini_haystack.generators.gemini.GoogleAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 0.5,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
},
"safety_settings": {6: 3},
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
},
}
)

assert gemini._model_name == "gemini-pro-vision"
assert gemini._generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [
Tool(
function_declarations=[
FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)
]
)
]
assert isinstance(gemini._model, GenerativeModel)


@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY not set")
def test_run():
gemini = GoogleAIGeminiGenerator(model="gemini-pro")
res = gemini.run("Tell me something cool")
assert len(res["answers"]) > 0

0 comments on commit 6f640ad

Please sign in to comment.