fix: remove the use of deprecated gemini models (#1032)
Amnah199 authored Aug 28, 2024
1 parent 50352b9 commit e49f4e1
Showing 7 changed files with 31 additions and 53 deletions.
@@ -21,10 +21,7 @@ class GoogleAIGeminiChatGenerator:
Completes chats using multimodal Gemini models through Google AI Studio.
It uses the [`ChatMessage`](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)
- dataclass to interact with the model. You can use the following models:
- - gemini-pro
- - gemini-ultra
- - gemini-pro-vision
+ dataclass to interact with the model.
### Usage example
@@ -103,7 +100,7 @@ def __init__(
self,
*,
api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), # noqa: B008
model: str = "gemini-pro-vision",
model: str = "gemini-1.5-flash",
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
@@ -114,17 +111,9 @@ def __init__(
To get an API key, visit: https://makersuite.google.com
- It supports the following models:
- * `gemini-pro`
- * `gemini-pro-vision`
- * `gemini-ultra`
:param api_key: Google AI Studio API key. To get a key,
see [Google AI Studio](https://makersuite.google.com).
- :param model: Name of the model to use. Supported models are:
- - gemini-pro
- - gemini-ultra
- - gemini-pro-vision
+ :param model: Name of the model to use. For available models, see https://ai.google.dev/gemini-api/docs/models/gemini.
:param generation_config: The generation configuration to use.
This can either be a `GenerationConfig` object or a dictionary of parameters.
For available parameters, see
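Not part of the diff: a minimal usage sketch for `GoogleAIGeminiChatGenerator` after this change, assuming the `google-ai-haystack` package and its top-level exports; `<MY_API_KEY>` is a placeholder.

```python
# Hedged sketch: chat completion with the new default model.
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

chat_generator = GoogleAIGeminiChatGenerator(
    model="gemini-1.5-flash",                   # now the default, so this argument can be omitted
    api_key=Secret.from_token("<MY_API_KEY>"),  # placeholder token
)

messages = [ChatMessage.from_user("What is the most interesting thing you know?")]
result = chat_generator.run(messages=messages)
for reply in result["replies"]:
    print(reply.content)
```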
@@ -55,7 +55,7 @@ class GoogleAIGeminiGenerator:
for url in URLS
]
gemini = GoogleAIGeminiGenerator(model="gemini-pro-vision", api_key=Secret.from_token("<MY_API_KEY>"))
gemini = GoogleAIGeminiGenerator(model="gemini-1.5-flash", api_key=Secret.from_token("<MY_API_KEY>"))
result = gemini.run(parts = ["What can you tell me about these robots?", *images])
for answer in result["replies"]:
print(answer)
@@ -66,7 +66,7 @@ def __init__(
self,
*,
api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), # noqa: B008
model: str = "gemini-pro-vision",
model: str = "gemini-1.5-flash",
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
@@ -77,13 +77,8 @@ def __init__(
To get an API key, visit: https://makersuite.google.com
- It supports the following models:
- * `gemini-pro`
- * `gemini-pro-vision`
- * `gemini-ultra`
:param api_key: Google AI Studio API key.
- :param model: Name of the model to use.
+ :param model: Name of the model to use. For available models, see https://ai.google.dev/gemini-api/docs/models/gemini
:param generation_config: The generation configuration to use.
This can either be a `GenerationConfig` object or a dictionary of parameters.
For available parameters, see
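Not part of the diff: since `gemini-1.5-flash` also accepts text-only prompts (unlike `gemini-pro-vision`, which expected image input), a call without images should work with the new default. A hedged sketch, assuming `google-ai-haystack` is installed and `GOOGLE_API_KEY` is exported:

```python
# Hedged sketch: text-only prompt with the new default model.
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator

gemini = GoogleAIGeminiGenerator()  # model defaults to "gemini-1.5-flash"
result = gemini.run(parts=["In one sentence, what does a robot vacuum do?"])
for answer in result["replies"]:
    print(answer)
```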
14 changes: 7 additions & 7 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -51,7 +51,7 @@ def test_init(monkeypatch):
tools=[tool],
)
mock_genai_configure.assert_called_once_with(api_key="test")
- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
@@ -67,7 +67,7 @@ def test_to_dict(monkeypatch):
"type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
"streaming_callback": None,
@@ -100,7 +100,7 @@ def test_to_dict_with_param(monkeypatch):
"type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
@@ -129,7 +129,7 @@ def test_from_dict(monkeypatch):
"type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
"streaming_callback": None,
@@ -138,7 +138,7 @@ def test_from_dict_with_param(monkeypatch):
}
)

- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config is None
assert gemini._safety_settings is None
assert gemini._tools is None
@@ -154,7 +154,7 @@ def test_from_dict_with_param(monkeypatch):
"type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
@@ -174,7 +174,7 @@ def test_from_dict_with_param(monkeypatch):
}
)

- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
14 changes: 7 additions & 7 deletions integrations/google_ai/tests/generators/test_gemini.py
@@ -68,7 +68,7 @@ def test_init(monkeypatch):
tools=[tool],
)
mock_genai_configure.assert_called_once_with(api_key="test")
- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
@@ -83,7 +83,7 @@ def test_to_dict(monkeypatch):
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
"generation_config": None,
"safety_settings": None,
@@ -135,7 +135,7 @@ def test_to_dict_with_param(monkeypatch):
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
"generation_config": {
"temperature": 0.5,
@@ -164,7 +164,7 @@ def test_from_dict_with_param(monkeypatch):
{
"type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
@@ -184,7 +184,7 @@ def test_from_dict_with_param(monkeypatch):
}
)

- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
@@ -206,7 +206,7 @@ def test_from_dict(monkeypatch):
{
"type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
Expand All @@ -226,7 +226,7 @@ def test_from_dict(monkeypatch):
}
)

- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
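Not part of the diff: the tests above pin the serialized default. A hedged sketch of that `to_dict()`/`from_dict()` round trip, assuming `google-ai-haystack` is installed; the key is a dummy placeholder and no API call is made:

```python
# Hedged sketch of the serialization round trip exercised by the tests above.
import os

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator

os.environ["GOOGLE_API_KEY"] = "dummy-key"  # placeholder so the env-var Secret resolves

gemini = GoogleAIGeminiGenerator()  # defaults to "gemini-1.5-flash"
data = gemini.to_dict()
assert data["init_parameters"]["model"] == "gemini-1.5-flash"

restored = GoogleAIGeminiGenerator.from_dict(data)
assert restored._model_name == "gemini-1.5-flash"
```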
@@ -27,9 +27,6 @@ class VertexAIGeminiChatGenerator:
"""
`VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models.
- `VertexAIGeminiChatGenerator` supports both `gemini-pro` and `gemini-pro-vision` models.
- Prompting with images requires `gemini-pro-vision`. Function calling, instead, requires `gemini-pro`.
Authenticates using Google Cloud Application Default Credentials (ADCs).
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).
@@ -51,7 +48,7 @@ class VertexAIGeminiChatGenerator:
def __init__(
self,
*,
model: str = "gemini-pro",
model: str = "gemini-1.5-flash",
project_id: str,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
@@ -66,7 +63,7 @@ def __init__(
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).
:param project_id: ID of the GCP project to use.
- :param model: Name of the model to use, defaults to "gemini-pro-vision".
+ :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
:param location: The default location to use when making API calls. If not set, us-central1 is used.
Defaults to None.
:param generation_config: Configuration for the generation process.
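Not part of the diff: a hedged usage sketch for the Vertex AI chat generator with the new default, assuming `google-vertex-haystack` is installed, Application Default Credentials are configured, and `my-gcp-project` is a placeholder project ID:

```python
# Hedged sketch: VertexAIGeminiChatGenerator with the new default model.
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator

chat = VertexAIGeminiChatGenerator(project_id="my-gcp-project")  # model defaults to "gemini-1.5-flash"
result = chat.run(messages=[ChatMessage.from_user("Summarize what Vertex AI is in one sentence.")])
for reply in result["replies"]:
    print(reply.content)
```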
@@ -26,9 +26,6 @@ class VertexAIGeminiGenerator:
"""
`VertexAIGeminiGenerator` enables text generation using Google Gemini models.
- `VertexAIGeminiGenerator` supports both `gemini-pro` and `gemini-pro-vision` models.
- Prompting with images requires `gemini-pro-vision`. Function calling, instead, requires `gemini-pro`.
Usage example:
```python
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator
@@ -55,7 +52,7 @@ class VertexAIGeminiGenerator:
def __init__(
self,
*,
model: str = "gemini-pro-vision",
model: str = "gemini-1.5-flash",
project_id: str,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
@@ -70,7 +67,7 @@ def __init__(
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).
:param project_id: ID of the GCP project to use.
- :param model: Name of the model to use.
+ :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
:param location: The default location to use when making API calls. If not set, us-central1 is used.
:param generation_config: The generation config to use.
Can either be a [`GenerationConfig`](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.GenerationConfig)
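Not part of the diff: likewise, a hedged sketch for `VertexAIGeminiGenerator` under the same assumptions (ADC configured, `my-gcp-project` as a placeholder project ID):

```python
# Hedged sketch: VertexAIGeminiGenerator with the new default model.
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

gemini = VertexAIGeminiGenerator(project_id="my-gcp-project")  # placeholder project ID
result = gemini.run(parts=["What is the most interesting thing you know?"])
for answer in result["replies"]:
    print(answer)
```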
14 changes: 7 additions & 7 deletions integrations/google_vertex/tests/test_gemini.py
@@ -55,7 +55,7 @@ def test_init(mock_vertexai_init, _mock_generative_model):
tools=[tool],
)
mock_vertexai_init.assert_called()
- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
@@ -71,7 +71,7 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model):
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"generation_config": None,
@@ -106,7 +106,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"generation_config": {
@@ -152,7 +152,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"project_id": "TestID123",
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
"tools": None,
@@ -161,7 +161,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
}
)

- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._safety_settings is None
assert gemini._tools is None
@@ -176,7 +176,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"project_id": "TestID123",
"model": "gemini-pro-vision",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
@@ -212,7 +212,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
}
)

- assert gemini._model_name == "gemini-pro-vision"
+ assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
