vertexai[patch]: enable agents for vertex (#137)
* flesh out bind_tools

---------

Co-authored-by: Bagatur <[email protected]>
ccurme and baskaryan authored Apr 11, 2024
1 parent e1ec207 commit 5523121
Showing 2 changed files with 114 additions and 30 deletions.
63 changes: 58 additions & 5 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -6,7 +6,20 @@
 import logging
 from dataclasses import dataclass, field
 from operator import itemgetter
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union, cast
+import uuid
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+    cast,
+)
 
 import proto  # type: ignore[import-untyped]
 from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
@@ -34,6 +47,7 @@
     ToolCallChunk,
     ToolMessage,
 )
+from langchain_core.tools import BaseTool, tool as tool_from_callable
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_functions import (
     JsonOutputFunctionsParser,
@@ -206,9 +220,17 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
             ]
         elif isinstance(message, ToolMessage):
             role = "function"
+            if (i > 0) and isinstance(history[i - 1], AIMessage):
+                # message.name can be null for ToolMessage
+                if history[i - 1].tool_calls:  # type: ignore
+                    name = history[i - 1].tool_calls[0]["name"]  # type: ignore
+                else:
+                    name = message.name
+            else:
+                name = message.name
             parts = [
                 Part.from_function_response(
-                    name=message.name,
+                    name=name,
                     response={
                         "content": message.content,
                     },
@@ -313,7 +335,7 @@ def _parse_response_candidate(
                 ToolCallChunk(
                     name=function_call.get("name"),
                     args=function_call.get("arguments"),
-                    id=function_call.get("id"),
+                    id=function_call.get("id", str(uuid.uuid4())),
                     index=function_call.get("index"),
                 )
             ]
@@ -334,7 +356,7 @@ def _parse_response_candidate(
                 ToolCall(
                     name=tool_call["name"],
                     args=tool_call["args"],
-                    id=tool_call.get("id"),
+                    id=tool_call.get("id", str(uuid.uuid4())),
                 )
                 for tool_call in tool_calls_dicts
             ]
@@ -343,7 +365,7 @@ def _parse_response_candidate(
                 InvalidToolCall(
                     name=function_call.get("name"),
                     args=function_call.get("arguments"),
-                    id=function_call.get("id"),
+                    id=function_call.get("id", str(uuid.uuid4())),
                     error=str(e),
                 )
             ]
@@ -844,6 +866,37 @@ class AnswerWithJustification(BaseModel):
         else:
             return llm | parser
 
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with Vertex tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        formatted_tools = []
+        for schema in tools:
+            if isinstance(schema, BaseTool) or (
+                isinstance(schema, type) and issubclass(schema, BaseModel)
+            ):
+                formatted_tools.append(schema)
+            elif callable(schema):
+                formatted_tools.append(tool_from_callable(schema))  # type: ignore
+            else:
+                raise ValueError(
+                    "Tool must be a BaseTool, Pydantic model, or callable."
+                )
+        return super().bind(functions=formatted_tools, **kwargs)
+
     def _start_chat(
         self, history: _ChatHistory, **kwargs: Any
     ) -> Union[ChatSession, CodeChatSession]:
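For orientation, here is a minimal usage sketch of the new bind_tools method (not part of the diff itself; the GetWeather schema, model name, and prompt are illustrative assumptions):

from langchain_core.pydantic_v1 import BaseModel

from langchain_google_vertexai import ChatVertexAI


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str


# bind_tools accepts Pydantic models, BaseTools, and plain callables;
# callables are wrapped via langchain_core.tools.tool before binding.
llm = ChatVertexAI(model_name="gemini-pro").bind_tools([GetWeather])
response = llm.invoke("What is the weather in Paris?")

# Each parsed tool call now carries an id (defaulting to uuid4 when the
# API response omits one), which agent loops use to match tool results.
for tool_call in response.tool_calls:
    print(tool_call["name"], tool_call["args"], tool_call["id"])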
81 changes: 56 additions & 25 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -7,13 +7,13 @@
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
+    BaseMessage,
     HumanMessage,
     SystemMessage,
-    ToolCall,
-    ToolCallChunk,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import tool
 
 from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory
@@ -262,37 +262,69 @@ def test_get_num_tokens_from_messages(model_name: str) -> None:
     assert token == 3
 
 
-@pytest.mark.release
-def test_chat_vertexai_gemini_function_calling() -> None:
-    class MyModel(BaseModel):
-        name: str
-        age: int
-
-    safety = {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
-    model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind(
-        functions=[MyModel]
-    )
-    message = HumanMessage(content="My name is Erick and I am 27 years old")
-    response = model.invoke([message])
+def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
+    """Check tool calls are as expected."""
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
     assert response.content == ""
     function_call = response.additional_kwargs.get("function_call")
     assert function_call
-    assert function_call["name"] == "MyModel"
+    assert function_call["name"] == expected_name
     arguments_str = function_call.get("arguments")
     assert arguments_str
     arguments = json.loads(arguments_str)
     assert arguments == {
         "name": "Erick",
         "age": 27.0,
     }
-    assert response.tool_calls == [
-        ToolCall(name="MyModel", args={"age": 27.0, "name": "Erick"}, id=None)
-    ]
+    tool_calls = response.tool_calls
+    assert len(tool_calls) == 1
+    tool_call = tool_calls[0]
+    assert tool_call["name"] == expected_name
+    assert tool_call["args"] == {"age": 27.0, "name": "Erick"}
+
 
+@pytest.mark.release
+def test_chat_vertexai_gemini_function_calling() -> None:
+    class MyModel(BaseModel):
+        name: str
+        age: int
+
+    safety = {
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
+    }
+    # Test .bind_tools with BaseModel
+    message = HumanMessage(content="My name is Erick and I am 27 years old")
+    model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind_tools(
+        [MyModel]
+    )
+    response = model.invoke([message])
+    _check_tool_calls(response, "MyModel")
+
+    # Test .bind_tools with function
+    def my_model(name: str, age: int) -> None:
+        """Invoke this with names and ages."""
+        pass
+
+    model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind_tools(
+        [my_model]
+    )
+    response = model.invoke([message])
+    _check_tool_calls(response, "my_model")
+
+    # Test .bind_tools with tool
+    @tool
+    def my_tool(name: str, age: int) -> None:
+        """Invoke this with names and ages."""
+        pass
+
+    model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind_tools(
+        [my_tool]
+    )
+    response = model.invoke([message])
+    _check_tool_calls(response, "my_tool")
+
+    # Test streaming
     stream = model.stream([message])
     first = True
     for chunk in stream:
@@ -302,8 +334,7 @@ class MyModel(BaseModel):
         else:
             gathered = gathered + chunk  # type: ignore
     assert isinstance(gathered, AIMessageChunk)
-    assert gathered.tool_call_chunks == [
-        ToolCallChunk(
-            name="MyModel", args='{"age": 27.0, "name": "Erick"}', id=None, index=None
-        )
-    ]
+    assert len(gathered.tool_call_chunks) == 1
+    tool_call_chunk = gathered.tool_call_chunks[0]
+    assert tool_call_chunk["name"] == "my_tool"
+    assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
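To see why the ToolMessage name fallback in chat_models.py matters, here is a hedged sketch of the agent-style round trip this commit enables. The ToolMessage below carries no name, so the Gemini history converter recovers it from the preceding AIMessage's tool_calls; the multiply tool and prompt are illustrative assumptions:

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool

from langchain_google_vertexai import ChatVertexAI


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


llm = ChatVertexAI(model_name="gemini-pro").bind_tools([multiply])
query = HumanMessage(content="What is 6 times 7?")
ai_msg = llm.invoke([query])

# Run the requested tool and feed the result back. No name is set on the
# ToolMessage; the history converter falls back to the name recorded in
# the previous AIMessage's tool_calls.
tool_call = ai_msg.tool_calls[0]
result = multiply.invoke(tool_call["args"])
tool_msg = ToolMessage(content=str(result), tool_call_id=tool_call["id"])

final = llm.invoke([query, ai_msg, tool_msg])
print(final.content)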
