From 6702bed8bfe1124360a86398c34051a842d0bec1 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Wed, 6 Mar 2024 20:09:35 +0100 Subject: [PATCH] Temp fix (#51) * Fixed temp=0 --- .../langchain_google_vertexai/llms.py | 2 +- libs/vertexai/tests/unit_tests/test_llm.py | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 libs/vertexai/tests/unit_tests/test_llm.py diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index 59fb147b..ca16de4a 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -217,7 +217,7 @@ def _default_params(self) -> Dict[str, Any]: default_value = default_params.get(param_name) if param_value or default_value: updated_params[param_name] = ( - param_value if param_value else default_value + param_value if param_value is not None else default_value ) return updated_params diff --git a/libs/vertexai/tests/unit_tests/test_llm.py b/libs/vertexai/tests/unit_tests/test_llm.py new file mode 100644 index 00000000..8d65221e --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_llm.py @@ -0,0 +1,49 @@ +from typing import Any, Dict +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from langchain_google_vertexai.llms import VertexAI + + +def test_vertexai_args_passed() -> None: + response_text = "Goodbye" + user_prompt = "Hello" + prompt_params: Dict[str, Any] = { + "max_output_tokens": 1, + "temperature": 0, + "top_k": 10, + "top_p": 0.5, + } + + # Mock the library to ensure the args are passed correctly + with patch("langchain_google_vertexai.llms.GenerativeModel") as model: + with patch("langchain_google_vertexai.llms.get_generation_info") as gen_info: + gen_info.return_value = {} + mock_response = MagicMock() + candidate = MagicMock() + candidate.text = response_text + mock_response.candidates = [candidate] + model_instance = MagicMock() + model_instance.generate_content.return_value = mock_response + model.return_value = model_instance + + llm = VertexAI(model_name="gemini-pro", **prompt_params) + response = llm.invoke("Hello") + assert response == response_text + model_instance.generate_content.assert_called_once + + assert model_instance.generate_content.call_args.args[0] == [user_prompt] + TestCase().assertCountEqual( + model_instance.generate_content.call_args.kwargs, + { + "stream": False, + "safety_settings": None, + "generation_config": { + "max_output_tokens": 1, + "temperature": 0, + "top_k": 10, + "top_p": 0.5, + "stop_sequences": None, + }, + }, + )