Skip to content

Commit

Permalink
Testing serialization on all python versions
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Sep 4, 2024
1 parent dcb0d3f commit de84f92
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/google_vertex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10". "3.11", "3.12"]

steps:
- name: Support longpaths
Expand Down
62 changes: 40 additions & 22 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,28 +116,45 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
tools=[tool],
)

serialized = gemini.to_dict()
assert (
serialized["type"]
== "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator"
)
assert serialized["init_parameters"]["model"] == "gemini-1.5-flash"
assert serialized["init_parameters"]["project_id"] == "TestID123"
assert serialized["init_parameters"]["location"] is None
assert serialized["init_parameters"]["generation_config"] == {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 2.0,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
}
assert serialized["init_parameters"]["safety_settings"] == {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 2.0,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
},
"safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
"streaming_callback": None,
"tools": [
{
"function_declarations": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
}
],
},
}
assert serialized["init_parameters"]["streaming_callback"] is None
assert len(serialized["init_parameters"]["tools"]) == 1
assert serialized["init_parameters"]["tools"][0]["function_declarations"][0] == GET_CURRENT_WEATHER_FUNC.to_dict()


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
Expand Down Expand Up @@ -211,8 +228,9 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
# assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._generation_config, GenerationConfig)
assert isinstance(gemini._tools[0], Tool)


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
Expand Down

0 comments on commit de84f92

Please sign in to comment.