From 03577e4f65ce0f4724a68d4ee1f78dbd68df79e7 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 4 Mar 2024 18:03:40 +0100 Subject: [PATCH] wip --- .../generators/google_ai/chat/gemini.py | 66 +++++++++++-------- .../components/generators/google_ai/gemini.py | 63 +++++++++++------- 2 files changed, 77 insertions(+), 52 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 3c84b4081..c3b486e7d 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -17,9 +17,10 @@ @component class GoogleAIGeminiChatGenerator: """ - GoogleAIGeminiGenerator is a multi modal generator supporting Gemini via Google Makersuite. + `GoogleAIGeminiChatGenerator` is a multimodal generator supporting Gemini via Google AI Studio. + It uses the `ChatMessage` dataclass to interact with the model. - Sample usage: + Usage example: ```python from haystack.utils import Secret from haystack.dataclasses.chat_message import ChatMessage @@ -40,7 +41,7 @@ class GoogleAIGeminiChatGenerator: ``` - This is a more advanced usage that also uses function calls: + Usage example with function calling: ```python from haystack.utils import Secret from haystack.dataclasses.chat_message import ChatMessage @@ -53,7 +54,7 @@ def get_current_weather(location: str, unit: str = "celsius") -> str: # Call a weather API and return some text ... - # Define the function interface so that Gemini can call it + # Define the function interface get_current_weather_func = FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -88,12 +89,6 @@ def get_current_weather(location: str, unit: str = "celsius") -> str: for reply in res["replies"]: print(reply.content) ``` - - Input: - - **messages** A list of ChatMessage objects. - - Output: - - **replies** A list of ChatMessage objects containing the one or more replies from the model. """ def __init__( @@ -106,7 +101,7 @@ def __init__( tools: Optional[List[Tool]] = None, ): """ - Initialize a GoogleAIGeminiChatGenerator instance. + Initialize a `GoogleAIGeminiChatGenerator` instance. To get an API key, visit: https://makersuite.google.com @@ -115,24 +110,18 @@ def __init__( * `gemini-pro-vision` * `gemini-ultra` - :param api_key: Google Makersuite API key. - :param model: Name of the model to use, defaults to "gemini-pro-vision" - :param generation_config: The generation config to use, defaults to None. - Can either be a GenerationConfig object or a dictionary of parameters. - Accepted parameters are: - - temperature - - top_p - - top_k - - candidate_count - - max_output_tokens - - stop_sequences - :param safety_settings: The safety settings to use, defaults to None. - A dictionary of HarmCategory to HarmBlockThreshold. - :param tools: The tools to use, defaults to None. - A list of Tool objects that can be used to modify the generation process. + :param api_key: Google AI Studio API key. + :param model: Name of the model to use. + :param generation_config: The generation config to use. + Can either be a `GenerationConfig` object or a dictionary of parameters. + For the available parameters, see + [the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig). + :param safety_settings: The safety settings to use. + A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. + For more information, see [the API reference](https://ai.google.dev/api) + :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). """ - # Authenticate, if api_key is None it will use the GOOGLE_API_KEY env variable genai.configure(api_key=api_key.resolve_value()) self._api_key = api_key @@ -155,6 +144,12 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A } def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ data = default_to_dict( self, api_key=self._api_key.to_dict(), @@ -173,6 +168,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) if (tools := data["init_parameters"].get("tools")) is not None: @@ -247,6 +250,15 @@ def _message_to_content(self, message: ChatMessage) -> Content: @component.output_types(replies=List[ChatMessage]) def run(self, messages: List[ChatMessage]): + """ + Generate text based on the provided messages. + + :param messages: + A list of `ChatMessage` instances, representing the input messages. + :returns: + A dictionary containing the following key: + - `replies`: A list containing the generated responses as `ChatMessage` instances. + """ history = [self._message_to_content(m) for m in messages[:-1]] session = self._model.start_chat(history=history) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 3929e7d5e..4baf0d1b1 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -17,9 +17,9 @@ @component class GoogleAIGeminiGenerator: """ - GoogleAIGeminiGenerator is a multi modal generator supporting Gemini via Google Makersuite. + `GoogleAIGeminiGenerator` is a multimodal generator supporting Gemini via Google AI Studio. - Sample usage: + Usage example: ```python from haystack.utils import Secret from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator @@ -30,7 +30,7 @@ class GoogleAIGeminiGenerator: print(answer) ``` - This is a more advanced usage that also uses text and images as input: + Multimodal usage example: ```python import requests from haystack.utils import Secret @@ -58,12 +58,6 @@ class GoogleAIGeminiGenerator: for answer in result["answers"]: print(answer) ``` - - Input: - - **parts** A eterogeneous list of strings, ByteStream or Part objects. - - Output: - - **answers** A list of strings or dictionaries with function calls. """ def __init__( @@ -76,7 +70,7 @@ def __init__( tools: Optional[List[Tool]] = None, ): """ - Initialize a GoogleAIGeminiGenerator instance. + Initialize a `GoogleAIGeminiGenerator` instance. To get an API key, visit: https://makersuite.google.com @@ -85,21 +79,16 @@ def __init__( * `gemini-pro-vision` * `gemini-ultra` - :param api_key: Google Makersuite API key. - :param model: Name of the model to use, defaults to "gemini-pro-vision" - :param generation_config: The generation config to use, defaults to None. - Can either be a GenerationConfig object or a dictionary of parameters. - Accepted parameters are: - - temperature - - top_p - - top_k - - candidate_count - - max_output_tokens - - stop_sequences - :param safety_settings: The safety settings to use, defaults to None. - A dictionary of HarmCategory to HarmBlockThreshold. - :param tools: The tools to use, defaults to None. - A list of Tool objects that can be used to modify the generation process. + :param api_key: Google AI Studio API key. + :param model: Name of the model to use. + :param generation_config: The generation config to use. + Can either be a `GenerationConfig` object or a dictionary of parameters. + For the available parameters, see + [the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig). + :param safety_settings: The safety settings to use. + A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. + For more information, see [the API reference](https://ai.google.dev/api) + :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). """ genai.configure(api_key=api_key.resolve_value()) @@ -123,6 +112,12 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A } def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ data = default_to_dict( self, api_key=self._api_key.to_dict(), @@ -141,6 +136,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) if (tools := data["init_parameters"].get("tools")) is not None: @@ -172,6 +175,16 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @component.output_types(answers=List[Union[str, Dict[str, str]]]) def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + """ + Generate text based on the given input parts. + + :param parts: + A heterogeneous list of strings, `ByteStream` or `Part` objects. + :returns: + A dictionary containing the following key: + - `answers`: A list of strings or dictionaries with function calls. + """ + converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")]