diff --git a/.github/workflows/google_vertex.yml b/.github/workflows/google_vertex.yml index 78ba5694b..cd82f0160 100644 --- a/.github/workflows/google_vertex.yml +++ b/.github/workflows/google_vertex.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.9", "3.10"] + python-version: ["3.9", "3.10". "3.11", "3.12"] steps: - name: Support longpaths diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 810028fac..0aaa4ac23 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -116,28 +116,45 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): tools=[tool], ) - serialized = gemini.to_dict() - assert ( - serialized["type"] - == "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator" - ) - assert serialized["init_parameters"]["model"] == "gemini-1.5-flash" - assert serialized["init_parameters"]["project_id"] == "TestID123" - assert serialized["init_parameters"]["location"] is None - assert serialized["init_parameters"]["generation_config"] == { - "temperature": 0.5, - "top_p": 0.5, - "top_k": 2.0, - "candidate_count": 1, - "max_output_tokens": 10, - "stop_sequences": ["stop"], - } - assert serialized["init_parameters"]["safety_settings"] == { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + "location": None, + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2.0, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + }, } - assert serialized["init_parameters"]["streaming_callback"] is None - assert len(serialized["init_parameters"]["tools"]) == 1 - assert serialized["init_parameters"]["tools"][0]["function_declarations"][0] == GET_CURRENT_WEATHER_FUNC.to_dict() @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") @@ -211,8 +228,9 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) + # assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) + assert isinstance(gemini._tools[0], Tool) @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")