diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index dd065af4b..20e143ba7 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -1,16 +1,16 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai from google.ai.generativelanguage import Content, Part from google.ai.generativelanguage import Tool as ToolProto from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool +from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory, Tool from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses.byte_stream import ByteStream +from haystack.dataclasses import ByteStream, StreamingChunk from haystack.dataclasses.chat_message import ChatMessage, ChatRole -from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) @@ -107,6 +107,7 @@ def __init__( generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Initializes a `GoogleAIGeminiChatGenerator` instance. @@ -132,6 +133,8 @@ def __init__( A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ genai.configure(api_key=api_key.resolve_value()) @@ -142,6 +145,7 @@ def __init__( self._safety_settings = safety_settings self._tools = tools self._model = GenerativeModel(self._model_name, tools=self._tools) + self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -162,6 +166,8 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None + data = default_to_dict( self, api_key=self._api_key.to_dict(), @@ -169,6 +175,7 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [] @@ -213,6 +220,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": data["init_parameters"]["safety_settings"] = { HarmCategory(k): HarmBlockThreshold(v) for k, v in safety_settings.items() } + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -274,16 +283,23 @@ def _message_to_content(self, message: ChatMessage) -> Content: return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage]): + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates text based on the provided messages. :param messages: A list of `ChatMessage` instances, representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - `replies`: A list containing the generated responses as `ChatMessage` instances. """ + streaming_callback = streaming_callback or self._streaming_callback history = [self._message_to_content(m) for m in messages[:-1]] session = self._model.start_chat(history=history) @@ -292,10 +308,22 @@ def run(self, messages: List[ChatMessage]): content=new_message, generation_config=self._generation_config, safety_settings=self._safety_settings, + stream=streaming_callback is not None, ) + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]: + """ + Extracts the responses from the Google AI response. + + :param response_body: The response from Google AI request. + :returns: The extracted responses. + """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part.text != "": replies.append(ChatMessage.from_system(part.text)) @@ -307,5 +335,23 @@ def run(self, messages: List[ChatMessage]): name=part.function_call.name, ) ) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: GenerateContentResponse, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + """ + Extracts the responses from the Google AI streaming response. + + :param stream: The streaming response from the Google AI request. + :param streaming_callback: The handler for the streaming response. + :returns: The extracted response with the content of all streaming chunks. + """ + responses = [] + for chunk in stream: + content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else "" + streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) + responses.append(content) + + combined_response = "".join(responses).lstrip() + return [ChatMessage.from_system(content=combined_response)] diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 07277e55a..2b0293468 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -1,15 +1,15 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai from google.ai.generativelanguage import Content, Part, Tool from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory from haystack.core.component import component from haystack.core.component.types import Variadic from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses.byte_stream import ByteStream -from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.dataclasses import ByteStream, StreamingChunk +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) @@ -70,6 +70,7 @@ def __init__( generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Initializes a `GoogleAIGeminiGenerator` instance. @@ -91,6 +92,8 @@ def __init__( A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ genai.configure(api_key=api_key.resolve_value()) @@ -100,6 +103,7 @@ def __init__( self._safety_settings = safety_settings self._tools = tools self._model = GenerativeModel(self._model_name, tools=self._tools) + self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -120,6 +124,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None data = default_to_dict( self, api_key=self._api_key.to_dict(), @@ -127,6 +132,7 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] @@ -156,6 +162,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": data["init_parameters"]["safety_settings"] = { HarmCategory(k): HarmBlockThreshold(v) for k, v in safety_settings.items() } + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -176,28 +184,45 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) @component.output_types(replies=List[Union[str, Dict[str, str]]]) - def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + def run( + self, + parts: Variadic[Union[str, ByteStream, Part]], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates text based on the given input parts. :param parts: A heterogeneous list of strings, `ByteStream` or `Part` objects. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - `replies`: A list of strings or dictionaries with function calls. """ + # check if streaming_callback is passed + streaming_callback = streaming_callback or self._streaming_callback converted_parts = [self._convert_part(p) for p in parts] - contents = [Content(parts=converted_parts, role="user")] res = self._model.generate_content( contents=contents, generation_config=self._generation_config, safety_settings=self._safety_settings, + stream=streaming_callback is not None, ) self._model.start_chat() + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerateContentResponse) -> List[str]: + """ + Extracts the responses from the Google AI request. + :param response_body: The response body from the Google AI request. + :returns: A list of string responses. + """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part.text != "": replies.append(part.text) @@ -207,5 +232,23 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): "args": dict(part.function_call.args.items()), } replies.append(function_call) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: GenerateContentResponse, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[str]: + """ + Extracts the responses from the Google AI streaming response. + :param stream: The streaming response from the Google AI request. + :param streaming_callback: The handler for the streaming response. + :returns: A list of string responses. + """ + + responses = [] + for chunk in stream: + content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else "" + streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) + responses.append(content) + + combined_response = ["".join(responses).lstrip()] + return combined_response diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 9b3124eab..0302a3da7 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -4,10 +4,30 @@ import pytest from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool +from haystack.dataclasses import StreamingChunk from haystack.dataclasses.chat_message import ChatMessage from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -21,26 +41,7 @@ def test_init(monkeypatch): top_k=0.5, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) with patch( "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure" ) as mock_genai_configure: @@ -60,6 +61,24 @@ def test_init(monkeypatch): def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator() + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-pro-vision", + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +def test_to_dict_with_param(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -69,26 +88,7 @@ def test_to_dict(monkeypatch): top_k=2, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): gemini = GoogleAIGeminiChatGenerator( @@ -110,6 +110,7 @@ def test_to_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -122,6 +123,31 @@ def test_to_dict(monkeypatch): def test_from_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-pro-vision", + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + ) + + assert gemini._model_name == "gemini-pro-vision" + assert gemini._generation_config is None + assert gemini._safety_settings is None + assert gemini._tools is None + assert isinstance(gemini._model, GenerativeModel) + + +def test_from_dict_with_param(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): gemini = GoogleAIGeminiChatGenerator.from_dict( { @@ -138,6 +164,7 @@ def test_from_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -198,6 +225,33 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert len(res["replies"]) > 0 +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") +def test_run_with_streaming_callback(): + streaming_callback_called = False + + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 + return {"weather": "sunny", "temperature": 21.8, "unit": unit} + + get_current_weather_func = FunctionDeclaration.from_function( + get_current_weather, + descriptions={ + "location": "The city and state, e.g. San Francisco, CA", + "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit", + }, + ) + + tool = Tool(function_declarations=[get_current_weather_func]) + gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) + messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + res = gemini_chat.run(messages=messages) + assert len(res["replies"]) > 0 + assert streaming_callback_called + + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") def test_past_conversation(): gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 35c7d196b..fe2e56e67 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -5,9 +5,29 @@ from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory +from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -58,6 +78,24 @@ def test_init(monkeypatch): def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): + gemini = GoogleAIGeminiGenerator() + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", + "init_parameters": { + "model": "gemini-pro-vision", + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +def test_to_dict_with_param(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -108,6 +146,7 @@ def test_to_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -117,7 +156,7 @@ def test_to_dict(monkeypatch): } -def test_from_dict(monkeypatch): +def test_from_dict_with_param(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): @@ -135,6 +174,7 @@ def test_from_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -154,33 +194,49 @@ def test_from_dict(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [ - Tool( - function_declarations=[ - FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], + assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] + assert isinstance(gemini._model, GenerativeModel) + + +def test_from_dict(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): + gemini = GoogleAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", + "init_parameters": { + "model": "gemini-pro-vision", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], }, - ) - ] + "safety_settings": {10: 3}, + "streaming_callback": None, + "tools": [ + b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" + b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" + b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" + ], + }, + } ) - ] + + assert gemini._model_name == "gemini-pro-vision" + assert gemini._generation_config == GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) @@ -189,3 +245,17 @@ def test_run(): gemini = GoogleAIGeminiGenerator(model="gemini-pro") res = gemini.run("Tell me something cool") assert len(res["replies"]) > 0 + + +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") +def test_run_with_streaming_callback(): + streaming_callback_called = False + + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + gemini = GoogleAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback) + res = gemini.run("Tell me something cool") + assert len(res["replies"]) > 0 + assert streaming_callback_called