added support for tools on VertexAI #14839

Closed · wants to merge 2 commits
88 changes: 71 additions & 17 deletions libs/community/langchain_community/chat_models/vertexai.py
@@ -18,6 +18,7 @@
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    FunctionMessage,
     HumanMessage,
     SystemMessage,
 )
@@ -41,7 +42,7 @@
         CodeChatSession,
         InputOutputTextPair,
     )
-    from vertexai.preview.generative_models import Content
+    from vertexai.preview.generative_models import Candidate, Content

 logger = logging.getLogger(__name__)

@@ -89,6 +90,8 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
 def _parse_chat_history_gemini(
     history: List[BaseMessage], project: Optional[str]
 ) -> List["Content"]:
+    from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
+    from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
     from vertexai.preview.generative_models import Content, Image, Part

     def _convert_to_prompt(part: Union[str, Dict]) -> Part:
@@ -113,23 +116,45 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Part:
             raise ValueError("Only text and image_url types are supported!")
         return Part.from_image(image)

+    def _convert_to_parts(message: BaseMessage) -> List[Part]:
+        raw_content = message.content
+        if isinstance(raw_content, str):
+            raw_content = [raw_content]
+        return [_convert_to_prompt(part) for part in raw_content]
+
     vertex_messages = []
     for i, message in enumerate(history):
         if i == 0 and isinstance(message, SystemMessage):
             raise ValueError("SystemMessages are not yet supported!")
         elif isinstance(message, AIMessage):
+            raw_function_call = message.additional_kwargs.get("function_call")
             role = "model"
+            if raw_function_call:
+                function_call = FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": raw_function_call["arguments"],
+                    }
+                )
+                gapic_part = GapicPart(function_call=function_call)
+                parts = [Part._from_gapic(gapic_part)]
         elif isinstance(message, HumanMessage):
             role = "user"
+            parts = _convert_to_parts(message)
+        elif isinstance(message, FunctionMessage):
+            role = "user"
+            parts = [
+                Part.from_function_response(
+                    name=message.name,
+                    response={
+                        "content": message.content,
+                    },
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )

-        raw_content = message.content
-        if isinstance(raw_content, str):
-            raw_content = [raw_content]
-        parts = [_convert_to_prompt(part) for part in raw_content]
         vertex_message = Content(role=role, parts=parts)
         vertex_messages.append(vertex_message)
     return vertex_messages
@@ -177,6 +202,23 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage:
     return question


+def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
+    try:
+        content = response_candidate.text
+    except ValueError:
+        content = ""
+
+    additional_kwargs = {}
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = {"name": first_part.function_call.name}
+        function_call["arguments"] = {
+            k: first_part.function_call.args[k] for k in first_part.function_call.args
+        }
+        additional_kwargs["function_call"] = function_call
+    return AIMessage(content=content, additional_kwargs=additional_kwargs)
+
+
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""

@@ -247,7 +289,6 @@ def _generate(
             )
             return generate_from_stream(stream_iter)

-        question = _get_question(messages)
         params = self._prepare_params(stop=stop, stream=False, **kwargs)
         msg_params = {}
         if "candidate_count" in params:
@@ -257,18 +298,24 @@
             history_gemini = _parse_chat_history_gemini(messages, project=self.project)
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
-            response = chat.send_message(message, generation_config=params)
+            tools = params.pop("tools") if "tools" in params else None
+            response = chat.send_message(message, generation_config=params, tools=tools)
+            generations = [
+                ChatGeneration(message=_parse_response_candidate(c))
+                for c in response.candidates
+            ]
         else:
+            question = _get_question(messages)
             history = _parse_chat_history(messages[:-1])
             examples = kwargs.get("examples") or self.examples
             if examples:
                 params["examples"] = _parse_examples(examples)
             chat = self._start_chat(history, **params)
             response = chat.send_message(question.content, **msg_params)
-        generations = [
-            ChatGeneration(message=AIMessage(content=r.text))
-            for r in response.candidates
-        ]
+            generations = [
+                ChatGeneration(message=AIMessage(content=r.text))
+                for r in response.candidates
+            ]
         return ChatResult(generations=generations)

     async def _agenerate(
@@ -305,7 +352,14 @@ async def _agenerate(
             history_gemini = _parse_chat_history_gemini(messages, project=self.project)
             message = history_gemini.pop()
             chat = self.client.start_chat(history=history_gemini)
-            response = await chat.send_message_async(message, generation_config=params)
+            tools = params.pop("tools") if "tools" in params else None
+            response = await chat.send_message_async(
+                message, generation_config=params, tools=tools
+            )
+            generations = [
+                ChatGeneration(message=_parse_response_candidate(c))
+                for c in response.candidates
+            ]
         else:
             question = _get_question(messages)
             history = _parse_chat_history(messages[:-1])
@@ -315,10 +369,10 @@ async def _agenerate(
             chat = self._start_chat(history, **params)
             response = await chat.send_message_async(question.content, **msg_params)

-        generations = [
-            ChatGeneration(message=AIMessage(content=r.text))
-            for r in response.candidates
-        ]
+            generations = [
+                ChatGeneration(message=AIMessage(content=r.text))
+                for r in response.candidates
+            ]
         return ChatResult(generations=generations)

     def _stream(
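To see how the pieces above fit together, here is a minimal round-trip sketch (illustrative, not part of the diff): it assumes the vertexai preview SDK is installed, uses the format_tool_to_vertex_tool helper added in the next file, and the Calculator tool and its hard-coded result are made up for the example.

from langchain_core.messages import FunctionMessage, HumanMessage
from langchain_core.tools import Tool

from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.utils.vertex_functions import format_tool_to_vertex_tool

# A stand-in tool; any callable works for the sketch.
calculator = Tool(
    name="Calculator",
    func=lambda q: "1024",
    description="useful for when you need to answer questions about math",
)

llm = ChatVertexAI(model_name="gemini-pro")
llm_with_tools = llm.bind(tools=[format_tool_to_vertex_tool(calculator)])

question = HumanMessage(content="What is 2 to the power of 10?")

# _generate pops `tools` out of params and passes it to chat.send_message;
# _parse_response_candidate then surfaces any function call on the reply.
response = llm_with_tools.invoke([question])
function_call = response.additional_kwargs.get("function_call")
# e.g. {"name": "Calculator", "arguments": {"question": "2 ** 10"}}

# Run the tool yourself, then send the result back as a FunctionMessage;
# _parse_chat_history_gemini converts it via Part.from_function_response.
if function_call:
    final = llm_with_tools.invoke(
        [
            question,
            response,
            FunctionMessage(name=function_call["name"], content="1024"),
        ]
    )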
41 changes: 41 additions & 0 deletions libs/community/langchain_community/utils/vertex_functions.py
@@ -0,0 +1,41 @@
+from typing import TYPE_CHECKING
+
+from langchain_core.tools import Tool
+
+from langchain_community.utils.openai_functions import (
+    FunctionDescription,
+    convert_pydantic_to_openai_function,
+)
+
+if TYPE_CHECKING:
+    from vertexai.preview.generative_models import Tool as VertexTool
+
+
+def format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
+    "Format tool into the Vertex function API."
+    if tool.args_schema:
+        return convert_pydantic_to_openai_function(
+            tool.args_schema, name=tool.name, description=tool.description
+        )

Review thread on the convert_pydantic_to_openai_function call above:

hwchase17 (Contributor): i tried this out and it errors, i think you need to remove title from the json schema that is generated

Author (Collaborator): @hwchase17 which error do you get?

I've just tried this again

    raw_tools = [
        Tool(
            name="Search",
            func=search.run,
            description="useful for when you need to answer questions about current events. You should ask targeted questions",
        ),
        Tool(
            name="Calculator",
            func=llm_math_chain.run,
            description="useful for when you need to answer questions about math",
        )
    ]
    tools = [format_tool_to_vertex_tool(tool) for tool in raw_tools]

and it seems to be working fine.

and there's also a working integration test:
+    else:
+        return {
+            "name": tool.name,
+            "description": tool.description,
+            "parameters": {
+                "properties": {
+                    "__arg1": {"type": "string"},
+                },
+                "required": ["__arg1"],
+                "type": "object",
+            },
+        }
+
+
+def format_tool_to_vertex_tool(tool: Tool) -> "VertexTool":
+    "Format tool into the Vertex Tool instance."
+
+    from vertexai.preview.generative_models import FunctionDeclaration
+    from vertexai.preview.generative_models import Tool as VertexTool
+
+    function_declaration = FunctionDeclaration(**format_tool_to_vertex_function(tool))
+    return VertexTool(function_declarations=[function_declaration])
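For illustration, a minimal usage sketch of the new helper (not part of the diff; the Weather tool and get_weather function are hypothetical, invented for this example):

from langchain_core.tools import Tool

from langchain_community.utils.vertex_functions import format_tool_to_vertex_tool


def get_weather(city: str) -> str:
    # Hypothetical tool body, used only for this example.
    return f"Sunny in {city}"


weather_tool = Tool(
    name="Weather",
    func=get_weather,
    description="useful for when you need to look up the current weather",
)

# With no args_schema, the helper falls back to the single __arg1 string
# parameter, wrapped into one FunctionDeclaration inside a Vertex Tool.
vertex_tool = format_tool_to_vertex_tool(weather_tool)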
libs/community/tests/integration_tests/chat_models/test_vertexai.py
@@ -24,6 +24,7 @@
     _parse_chat_history,
     _parse_examples,
 )
+from langchain_community.utils.vertex_functions import format_tool_to_vertex_tool

 model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]

@@ -109,7 +110,7 @@ def test_vertexai_single_call_with_context() -> None:


 def test_multimodal() -> None:
-    llm = ChatVertexAI(model_name="gemini-ultra-vision")
+    llm = ChatVertexAI(model_name="gemini-pro-vision")
     gcs_url = (
         "gs://cloud-samples-data/generative-ai/image/"
         "320px-Felis_catus-cat_on_snow.jpg"
@@ -128,7 +129,7 @@ def test_multimodal() -> None:


 def test_multimodal_history() -> None:
-    llm = ChatVertexAI(model_name="gemini-ultra-vision")
+    llm = ChatVertexAI(model_name="gemini-pro-vision")
     gcs_url = (
         "gs://cloud-samples-data/generative-ai/image/"
         "320px-Felis_catus-cat_on_snow.jpg"
@@ -293,3 +294,47 @@ def test_parse_examples_failes_wrong_sequence() -> None:
         str(exc_info.value)
         == "Expect examples to have an even amount of messages, got 1."
     )
+
+
+def test_tools() -> None:
+    from langchain.agents import AgentExecutor, Tool
+    from langchain.agents.format_scratchpad import format_to_openai_function_messages
+    from langchain.agents.output_parsers import VertexAIFunctionsAgentOutputParser
+    from langchain.chains import LLMMathChain
+    from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+    llm = ChatVertexAI(model_name="gemini-pro")
+    math_chain = LLMMathChain.from_llm(llm=llm)
+    raw_tools = [
+        Tool(
+            name="Calculator",
+            func=math_chain.run,
+            description="useful for when you need to answer questions about math",
+        )
+    ]
+    tools = [format_tool_to_vertex_tool(t) for t in raw_tools]
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("user", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+    llm_with_tools = llm.bind(tools=tools)
+
+    agent = (
+        {
+            "input": lambda x: x["input"],
+            "agent_scratchpad": lambda x: format_to_openai_function_messages(
+                x["intermediate_steps"]
+            ),
+        }
+        | prompt
+        | llm_with_tools
+        | VertexAIFunctionsAgentOutputParser()
+    )
+    agent_executor = AgentExecutor(agent=agent, tools=raw_tools, verbose=True)
+
+    response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
+    assert isinstance(response, dict)
+    assert response["input"] == "What is 6 raised to the 0.43 power?"
+    assert round(float(response["output"]), 3) == 2.161
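As a quick sanity check of the value asserted in test_tools (plain arithmetic, no model call involved):

import math

print(round(6**0.43, 3))                        # 2.161
print(round(math.exp(0.43 * math.log(6)), 3))   # same computation: 2.161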
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/output_parsers/__init__.py
@@ -12,6 +12,7 @@
 from langchain.agents.output_parsers.json import JSONAgentOutputParser
 from langchain.agents.output_parsers.openai_functions import (
     OpenAIFunctionsAgentOutputParser,
+    VertexAIFunctionsAgentOutputParser,
 )
 from langchain.agents.output_parsers.react_json_single_input import (
     ReActJsonSingleInputOutputParser,
@@ -29,4 +30,5 @@
     "OpenAIFunctionsAgentOutputParser",
     "XMLAgentOutputParser",
     "JSONAgentOutputParser",
+    "VertexAIFunctionsAgentOutputParser",
 ]
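The diff only re-exports VertexAIFunctionsAgentOutputParser; its implementation is not shown in this PR. Presumably it mirrors OpenAIFunctionsAgentOutputParser but accepts the arguments dict that _parse_response_candidate produces (the OpenAI parser expects a JSON string). A hedged sketch of the expected behavior, under that assumption:

from langchain.agents.output_parsers import VertexAIFunctionsAgentOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import AIMessage

parser = VertexAIFunctionsAgentOutputParser()

# An AIMessage carrying a function_call should parse into an AgentAction ...
action = parser.invoke(
    AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "Calculator",
                "arguments": {"question": "6^0.43"},
            }
        },
    )
)
assert isinstance(action, AgentAction)

# ... while a plain text reply should parse into an AgentFinish.
finish = parser.invoke(AIMessage(content="The answer is 2.161."))
assert isinstance(finish, AgentFinish)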