diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
index b032169df..fd3ddd7bf 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
@@ -101,11 +101,11 @@ def __init__(
         genai.configure(api_key=api_key.resolve_value())
 
         self._api_key = api_key
-        self._model_name = model
-        self._generation_config = generation_config
-        self._safety_settings = safety_settings
-        self._model = GenerativeModel(self._model_name)
-        self._streaming_callback = streaming_callback
+        self.model_name = model
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
+        self._model = GenerativeModel(self.model_name)
+        self.streaming_callback = streaming_callback
 
     def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(config, dict):
@@ -126,13 +126,13 @@ def to_dict(self) -> Dict[str, Any]:
 
         :returns: Dictionary with serialized data.
         """
-        callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
+        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         data = default_to_dict(
             self,
             api_key=self._api_key.to_dict(),
-            model=self._model_name,
-            generation_config=self._generation_config,
-            safety_settings=self._safety_settings,
+            model=self.model_name,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             streaming_callback=callback_name,
         )
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
@@ -198,13 +198,13 @@ def run(
         """
         # check if streaming_callback is passed
-        streaming_callback = streaming_callback or self._streaming_callback
+        streaming_callback = streaming_callback or self.streaming_callback
         converted_parts = [self._convert_part(p) for p in parts]
 
         contents = [Content(parts=converted_parts, role="user")]
         res = self._model.generate_content(
             contents=contents,
-            generation_config=self._generation_config,
-            safety_settings=self._safety_settings,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             stream=streaming_callback is not None,
         )
         self._model.start_chat()
diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py
index 07d194a59..eb09514eb 100644
--- a/integrations/google_ai/tests/generators/test_gemini.py
+++ b/integrations/google_ai/tests/generators/test_gemini.py
@@ -18,7 +18,7 @@ def test_init(monkeypatch):
         max_output_tokens=10,
         temperature=0.5,
         top_p=0.5,
-        top_k=0.5,
+        top_k=1,
     )
 
     safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
@@ -28,9 +28,9 @@ def test_init(monkeypatch):
             safety_settings=safety_settings,
         )
     mock_genai_configure.assert_called_once_with(api_key="test")
-    assert gemini._model_name == "gemini-1.5-flash"
-    assert gemini._generation_config == generation_config
-    assert gemini._safety_settings == safety_settings
+    assert gemini.model_name == "gemini-1.5-flash"
+    assert gemini.generation_config == generation_config
+    assert gemini.safety_settings == safety_settings
     assert isinstance(gemini._model, GenerativeModel)
 
 
@@ -105,7 +105,7 @@ def test_from_dict_with_param(monkeypatch):
                 "generation_config": {
                     "temperature": 0.5,
                     "top_p": 0.5,
-                    "top_k": 0.5,
+                    "top_k": 1,
                     "candidate_count": 1,
                     "max_output_tokens": 10,
                     "stop_sequences": ["stop"],
@@ -116,16 +116,16 @@ def test_from_dict_with_param(monkeypatch):
         }
     )
 
-    assert gemini._model_name == "gemini-1.5-flash"
-    assert gemini._generation_config == GenerationConfig(
+    assert gemini.model_name == "gemini-1.5-flash"
+    assert gemini.generation_config == GenerationConfig(
         candidate_count=1,
         stop_sequences=["stop"],
         max_output_tokens=10,
         temperature=0.5,
         top_p=0.5,
-        top_k=0.5,
+        top_k=1,
     )
-    assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
+    assert gemini.safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     assert isinstance(gemini._model, GenerativeModel)
 
 
@@ -141,7 +141,7 @@ def test_from_dict(monkeypatch):
                 "generation_config": {
                     "temperature": 0.5,
                     "top_p": 0.5,
-                    "top_k": 0.5,
+                    "top_k": 1,
                     "candidate_count": 1,
                     "max_output_tokens": 10,
                     "stop_sequences": ["stop"],
@@ -152,16 +152,16 @@ def test_from_dict(monkeypatch):
         }
     )
 
-    assert gemini._model_name == "gemini-1.5-flash"
-    assert gemini._generation_config == GenerationConfig(
+    assert gemini.model_name == "gemini-1.5-flash"
+    assert gemini.generation_config == GenerationConfig(
         candidate_count=1,
         stop_sequences=["stop"],
         max_output_tokens=10,
         temperature=0.5,
         top_p=0.5,
-        top_k=0.5,
+        top_k=1,
     )
-    assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
+    assert gemini.safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     assert isinstance(gemini._model, GenerativeModel)