Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Mar 4, 2024
1 parent de56507 commit 03577e4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
@component
class GoogleAIGeminiChatGenerator:
"""
GoogleAIGeminiGenerator is a multi modal generator supporting Gemini via Google Makersuite.
`GoogleAIGeminiChatGenerator` is a multimodal generator supporting Gemini via Google AI Studio.
It uses the `ChatMessage` dataclass to interact with the model.
Sample usage:
Usage example:
```python
from haystack.utils import Secret
from haystack.dataclasses.chat_message import ChatMessage
Expand All @@ -40,7 +41,7 @@ class GoogleAIGeminiChatGenerator:
```
This is a more advanced usage that also uses function calls:
Usage example with function calling:
```python
from haystack.utils import Secret
from haystack.dataclasses.chat_message import ChatMessage
Expand All @@ -53,7 +54,7 @@ def get_current_weather(location: str, unit: str = "celsius") -> str:
# Call a weather API and return some text
...
# Define the function interface so that Gemini can call it
# Define the function interface
get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
Expand Down Expand Up @@ -88,12 +89,6 @@ def get_current_weather(location: str, unit: str = "celsius") -> str:
for reply in res["replies"]:
print(reply.content)
```
Input:
- **messages** A list of ChatMessage objects.
Output:
- **replies** A list of ChatMessage objects containing the one or more replies from the model.
"""

def __init__(
Expand All @@ -106,7 +101,7 @@ def __init__(
tools: Optional[List[Tool]] = None,
):
"""
Initialize a GoogleAIGeminiChatGenerator instance.
Initialize a `GoogleAIGeminiChatGenerator` instance.
To get an API key, visit: https://makersuite.google.com
Expand All @@ -115,24 +110,18 @@ def __init__(
* `gemini-pro-vision`
* `gemini-ultra`
:param api_key: Google Makersuite API key.
:param model: Name of the model to use, defaults to "gemini-pro-vision"
:param generation_config: The generation config to use, defaults to None.
Can either be a GenerationConfig object or a dictionary of parameters.
Accepted parameters are:
- temperature
- top_p
- top_k
- candidate_count
- max_output_tokens
- stop_sequences
:param safety_settings: The safety settings to use, defaults to None.
A dictionary of HarmCategory to HarmBlockThreshold.
:param tools: The tools to use, defaults to None.
A list of Tool objects that can be used to modify the generation process.
:param api_key: Google AI Studio API key.
:param model: Name of the model to use.
:param generation_config: The generation config to use.
Can either be a `GenerationConfig` object or a dictionary of parameters.
For the available parameters, see
[the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig).
:param safety_settings: The safety settings to use.
A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values.
For more information, see [the API reference](https://ai.google.dev/api)
:param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling).
"""

# Authenticate, if api_key is None it will use the GOOGLE_API_KEY env variable
genai.configure(api_key=api_key.resolve_value())

self._api_key = api_key
Expand All @@ -155,6 +144,12 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
data = default_to_dict(
self,
api_key=self._api_key.to_dict(),
Expand All @@ -173,6 +168,14 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])

if (tools := data["init_parameters"].get("tools")) is not None:
Expand Down Expand Up @@ -247,6 +250,15 @@ def _message_to_content(self, message: ChatMessage) -> Content:

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage]):
"""
Generate text based on the provided messages.
:param messages:
A list of `ChatMessage` instances, representing the input messages.
:returns:
A dictionary containing the following key:
- `replies`: A list containing the generated responses as `ChatMessage` instances.
"""
history = [self._message_to_content(m) for m in messages[:-1]]
session = self._model.start_chat(history=history)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
@component
class GoogleAIGeminiGenerator:
"""
GoogleAIGeminiGenerator is a multi modal generator supporting Gemini via Google Makersuite.
`GoogleAIGeminiGenerator` is a multimodal generator supporting Gemini via Google AI Studio.
Sample usage:
Usage example:
```python
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator
Expand All @@ -30,7 +30,7 @@ class GoogleAIGeminiGenerator:
print(answer)
```
This is a more advanced usage that also uses text and images as input:
Multimodal usage example:
```python
import requests
from haystack.utils import Secret
Expand Down Expand Up @@ -58,12 +58,6 @@ class GoogleAIGeminiGenerator:
for answer in result["answers"]:
print(answer)
```
Input:
- **parts** A eterogeneous list of strings, ByteStream or Part objects.
Output:
- **answers** A list of strings or dictionaries with function calls.
"""

def __init__(
Expand All @@ -76,7 +70,7 @@ def __init__(
tools: Optional[List[Tool]] = None,
):
"""
Initialize a GoogleAIGeminiGenerator instance.
Initialize a `GoogleAIGeminiGenerator` instance.
To get an API key, visit: https://makersuite.google.com
Expand All @@ -85,21 +79,16 @@ def __init__(
* `gemini-pro-vision`
* `gemini-ultra`
:param api_key: Google Makersuite API key.
:param model: Name of the model to use, defaults to "gemini-pro-vision"
:param generation_config: The generation config to use, defaults to None.
Can either be a GenerationConfig object or a dictionary of parameters.
Accepted parameters are:
- temperature
- top_p
- top_k
- candidate_count
- max_output_tokens
- stop_sequences
:param safety_settings: The safety settings to use, defaults to None.
A dictionary of HarmCategory to HarmBlockThreshold.
:param tools: The tools to use, defaults to None.
A list of Tool objects that can be used to modify the generation process.
:param api_key: Google AI Studio API key.
:param model: Name of the model to use.
:param generation_config: The generation config to use.
Can either be a `GenerationConfig` object or a dictionary of parameters.
For the available parameters, see
[the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig).
:param safety_settings: The safety settings to use.
A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values.
For more information, see [the API reference](https://ai.google.dev/api)
:param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling).
"""
genai.configure(api_key=api_key.resolve_value())

Expand All @@ -123,6 +112,12 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
data = default_to_dict(
self,
api_key=self._api_key.to_dict(),
Expand All @@ -141,6 +136,14 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])

if (tools := data["init_parameters"].get("tools")) is not None:
Expand Down Expand Up @@ -172,6 +175,16 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:

@component.output_types(answers=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generate text based on the given input parts.
:param parts:
A heterogeneous list of strings, `ByteStream` or `Part` objects.
:returns:
A dictionary containing the following key:
- `answers`: A list of strings or dictionaries with function calls.
"""

converted_parts = [self._convert_part(p) for p in parts]

contents = [Content(parts=converted_parts, role="user")]
Expand Down

0 comments on commit 03577e4

Please sign in to comment.