Skip to content

Commit

Permalink
Merge branch 'main' into update-ruff-commands-settings
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Sep 26, 2024
2 parents 95f387a + cdb7dff commit 579bd7a
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HarmCategory,
Part,
Tool,
ToolConfig,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
system_instruction: Optional[Union[str, ByteStream, Part]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Expand All @@ -76,8 +79,11 @@ def __init__(
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
the list of supported arguments.
:param tool_config: The tool config to use. See the documentation for [ToolConfig]
(https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
:param system_instruction: Default system instruction to use for generating content.
:param streaming_callback: A callback function that is called when a new token is received from
the stream. The callback function accepts StreamingChunk as an argument.
the stream. The callback function accepts StreamingChunk as an argument.
"""

Expand All @@ -87,13 +93,25 @@ def __init__(
self._model_name = model
self._project_id = project_id
self._location = location
self._model = GenerativeModel(self._model_name)

# model parameters
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction
self._streaming_callback = streaming_callback

# except streaming_callback, all other model parameters can be passed during initialization
self._model = GenerativeModel(
self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
)

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
Expand All @@ -106,6 +124,17 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
"""Serializes the ToolConfig object into a dictionary."""
mode = tool_config._gapic_tool_config.function_calling_config.mode
allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names
config_dict = {"function_calling_config": {"mode": mode}}

if allowed_function_names:
config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names

return config_dict

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -123,10 +152,14 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
Expand All @@ -141,10 +174,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator":
:returns:
Deserialized component.
"""

def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig:
"""Deserializes the ToolConfig object from a dictionary."""
function_calling_config = config_dict["function_calling_config"]
return ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=function_calling_config["mode"],
allowed_function_names=function_calling_config.get("allowed_function_names"),
)
)

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config)
if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
Expand Down Expand Up @@ -212,9 +258,6 @@ def run(
new_message = self._message_to_part(messages[-1])
res = session.send_message(
content=new_message,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
stream=streaming_callback is not None,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HarmCategory,
Part,
Tool,
ToolConfig,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,6 +59,8 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
system_instruction: Optional[Union[str, ByteStream, Part]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Expand Down Expand Up @@ -86,6 +89,8 @@ def __init__(
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
the list of supported arguments.
:param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
:param system_instruction: Default system instruction to use for generating content.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
"""
Expand All @@ -96,13 +101,25 @@ def __init__(
self._model_name = model
self._project_id = project_id
self._location = location
self._model = GenerativeModel(self._model_name)

# model parameters
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction
self._streaming_callback = streaming_callback

# except streaming_callback, all other model parameters can be passed during initialization
self._model = GenerativeModel(
self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
)

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
Expand All @@ -115,6 +132,18 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
"""Serializes the ToolConfig object into a dictionary."""

mode = tool_config._gapic_tool_config.function_calling_config.mode
allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names
config_dict = {"function_calling_config": {"mode": mode}}

if allowed_function_names:
config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names

return config_dict

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -132,10 +161,14 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
Expand All @@ -150,10 +183,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator":
:returns:
Deserialized component.
"""

def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig:
"""Deserializes the ToolConfig object from a dictionary."""
function_calling_config = config_dict["function_calling_config"]
return ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=function_calling_config["mode"],
allowed_function_names=function_calling_config.get("allowed_function_names"),
)
)

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config)
if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
Expand Down Expand Up @@ -188,11 +234,9 @@ def run(
converted_parts = [self._convert_part(p) for p in parts]

contents = [Content(parts=converted_parts, role="user")]

res = self._model.generate_content(
contents=contents,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
stream=streaming_callback is not None,
)
self._model.start_chat()
Expand Down
42 changes: 42 additions & 0 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
HarmCategory,
Part,
Tool,
ToolConfig,
)

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator
Expand Down Expand Up @@ -60,19 +61,29 @@ def test_init(mock_vertexai_init, _mock_generative_model):
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)
mock_vertexai_init.assert_called()
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert gemini._tool_config == tool_config
assert gemini._system_instruction == "Please provide brief answers."


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
Expand All @@ -92,6 +103,8 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model):
"safety_settings": None,
"streaming_callback": None,
"tools": None,
"tool_config": None,
"system_instruction": None,
},
}

Expand All @@ -110,12 +123,20 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)

assert gemini.to_dict() == {
Expand Down Expand Up @@ -155,6 +176,13 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
]
}
],
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
},
}

Expand All @@ -180,6 +208,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id == "TestID123"
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
assert gemini._system_instruction is None
assert gemini._generation_config is None


Expand Down Expand Up @@ -222,6 +252,13 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
]
}
],
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
"streaming_callback": None,
},
}
Expand All @@ -231,7 +268,12 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id == "TestID123"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._tool_config, ToolConfig)
assert isinstance(gemini._generation_config, GenerationConfig)
assert gemini._system_instruction == "Please provide brief answers."
assert (
gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY
)


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
Expand Down
Loading

0 comments on commit 579bd7a

Please sign in to comment.