Merge branch 'main' into snowflake-privatekey-update
iireland-ii authored Nov 14, 2024
2 parents 8de2b06 + c2d1b20 commit 3f7a045
Showing 3 changed files with 31 additions and 167 deletions.
10 changes: 10 additions & 0 deletions integrations/google_vertex/CHANGELOG.md
@@ -1,5 +1,15 @@
# Changelog

## [integrations/google_vertex-v3.0.0] - 2024-11-14

### 🐛 Bug Fixes

- VertexAIGeminiGenerator - remove support for tools and change output type (#1180)

### ⚙️ Miscellaneous Tasks

- Fix Vertex tests (#1163)

## [integrations/google_vertex-v2.2.0] - 2024-10-23

### 🐛 Bug Fixes
63 changes: 12 additions & 51 deletions integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py

@@ -15,8 +15,6 @@
HarmBlockThreshold,
HarmCategory,
Part,
Tool,
ToolConfig,
)

logger = logging.getLogger(__name__)
@@ -50,6 +48,16 @@ class VertexAIGeminiGenerator:
```
"""

def __new__(cls, *_, **kwargs):
if "tools" in kwargs or "tool_config" in kwargs:
msg = (
"VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. "
"Use VertexAIGeminiChatGenerator instead."
)
raise TypeError(msg)
return super(VertexAIGeminiGenerator, cls).__new__(cls) # noqa: UP008
# super(__class__, cls) is needed because of the component decorator

def __init__(
self,
*,
@@ -58,8 +66,6 @@ def __init__(
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
system_instruction: Optional[Union[str, ByteStream, Part]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
@@ -86,10 +92,6 @@ def __init__(
for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold)
and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory)
for more details.
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
the list of supported arguments.
:param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
:param system_instruction: Default system instruction to use for generating content.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
@@ -105,8 +107,6 @@ def __init__(
# model parameters
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction
self._streaming_callback = streaming_callback

@@ -115,8 +115,6 @@ def __init__(
self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
)

@@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
"""Serializes the ToolConfig object into a dictionary."""

mode = tool_config._gapic_tool_config.function_calling_config.mode
allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names
config_dict = {"function_calling_config": {"mode": mode}}

if allowed_function_names:
config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names

return config_dict

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
@@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]:
location=self._location,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)

if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
@@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator":
Deserialized component.
"""

def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig:
"""Deserializes the ToolConfig object from a dictionary."""
function_calling_config = config_dict["function_calling_config"]
return ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=function_calling_config["mode"],
allowed_function_names=function_calling_config.get("allowed_function_names"),
)
)

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config)
if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
@@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(replies=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[str])
def run(
self,
parts: Variadic[Union[str, ByteStream, Part]],
@@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]:
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
replies.append(function_call)
return replies

def _get_stream_response(
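For context, here is a minimal, hedged sketch of how the generator behaves after this change. It is not taken from the repository: the project ID and location below are placeholders, and it assumes the Google Vertex integration is installed and Vertex AI credentials are configured.

```python
# Hedged sketch (not from the repository). Assumes the Google Vertex Haystack
# integration is installed and Google Cloud credentials are available.
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

# Plain text generation is unchanged; run() now returns only strings under "replies".
generator = VertexAIGeminiGenerator(project_id="my-gcp-project", location="us-central1")

# The new __new__ guard rejects tool parameters at construction time;
# function calling is handled by VertexAIGeminiChatGenerator instead.
try:
    VertexAIGeminiGenerator(tools=["my-tool"])
except TypeError as err:
    print(err)  # VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. ...
```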
125 changes: 9 additions & 116 deletions integrations/google_vertex/tests/test_gemini.py
@@ -1,38 +1,17 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.dataclasses import StreamingChunk
from vertexai.generative_models import (
FunctionDeclaration,
GenerationConfig,
HarmBlockThreshold,
HarmCategory,
Tool,
ToolConfig,
)

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

GET_CURRENT_WEATHER_FUNC = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)


@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
@@ -48,32 +27,28 @@ def test_init(mock_vertexai_init, _mock_generative_model):
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)
mock_vertexai_init.assert_called()
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert gemini._tool_config == tool_config
assert gemini._system_instruction == "Please provide brief answers."


def test_init_fails_with_tools_or_tool_config():
with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"):
VertexAIGeminiGenerator(tools=["tool1", "tool2"])

with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"):
VertexAIGeminiGenerator(tool_config={"custom": "config"})


@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
def test_to_dict(_mock_vertexai_init, _mock_generative_model):
@@ -88,8 +63,6 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model):
"generation_config": None,
"safety_settings": None,
"streaming_callback": None,
"tools": None,
"tool_config": None,
"system_instruction": None,
},
}
@@ -108,21 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)
assert gemini.to_dict() == {
@@ -141,34 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
},
"safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
"streaming_callback": None,
"tools": [
{
"function_declarations": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
"property_ordering": ["location", "unit"],
},
}
]
}
],
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
},
}
@@ -186,9 +121,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
"tools": None,
"streaming_callback": None,
"tool_config": None,
"system_instruction": None,
},
}
@@ -198,8 +131,6 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id is None
assert gemini._location is None
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
assert gemini._system_instruction is None
assert gemini._generation_config is None

@@ -223,40 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"stop_sequences": ["stop"],
},
"safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
"tools": [
{
"function_declarations": [
{
"name": "get_current_weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
"description": "Get the current weather in a given location",
}
]
}
],
"streaming_callback": None,
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
},
}
@@ -266,13 +164,8 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id == "TestID123"
assert gemini._location == "TestLocation"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._generation_config, GenerationConfig)
assert isinstance(gemini._tool_config, ToolConfig)
assert gemini._system_instruction == "Please provide brief answers."
assert (
gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY
)


@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
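To complement the tests above, here is a small hedged sketch of the serialization round trip after this change. It mirrors the `test_to_dict`/`test_from_dict` expectations shown in the diff and mocks Vertex AI initialization, so no Google Cloud credentials are needed; the default model name is assumed to be `gemini-1.5-flash`, as in the tests.

```python
# Hedged sketch mirroring the serialization tests; Vertex AI calls are mocked,
# so nothing is sent to Google Cloud.
from unittest.mock import patch

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

with patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init"), patch(
    "haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel"
):
    gemini = VertexAIGeminiGenerator()
    data = gemini.to_dict()

    # The serialized init parameters no longer include tool state.
    assert "tools" not in data["init_parameters"]
    assert "tool_config" not in data["init_parameters"]

    # The round trip works without any tool-related keys.
    restored = VertexAIGeminiGenerator.from_dict(data)
    assert restored._model_name == "gemini-1.5-flash"
```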
