Skip to content

Commit

Permalink
Fixed support for custom tool (#64)
Browse files Browse the repository at this point in the history
* Fixed support for custom tool
  • Loading branch information
moiz-stri authored Mar 14, 2024
1 parent 75e6891 commit c91f98c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
12 changes: 10 additions & 2 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,23 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:


def _format_tools_to_vertex_tool(
tools: List[Union[BaseTool, Type[BaseModel]]],
tools: List[Union[BaseTool, Type[BaseModel], dict]],
) -> List[VertexTool]:
"Format tool into the Vertex Tool instance."
function_declarations = []
for tool in tools:
if isinstance(tool, BaseTool):
func = _format_tool_to_vertex_function(tool)
else:
elif isinstance(tool, type) and issubclass(tool, BaseModel):
func = _format_pydantic_to_vertex_function(tool)
elif isinstance(tool, dict):
func = {
"name": tool["name"],
"description": tool.pop("description"),
"parameters": _get_parameters_from_schema(tool["parameters"]),
}
else:
raise ValueError(f"Unsupported tool call type {tool}")
function_declarations.append(FunctionDeclaration(**func))

return [VertexTool(function_declarations=function_declarations)]
Expand Down
47 changes: 47 additions & 0 deletions libs/vertexai/tests/integration_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,53 @@ def test_tools() -> None:
assert round(float(just_numbers), 2) == 2.16


@pytest.mark.extended
def test_custom_tool() -> None:
from langchain.agents import AgentExecutor, tool
from langchain.agents.format_scratchpad import (
format_to_openai_function_messages,
)

@tool("search", return_direct=True)
def search(query: str) -> str:
"""Look up things online."""
return "LangChain"

tools = [search]

llm = ChatVertexAI(
model_name="gemini-pro", temperature=0.0, convert_system_message_to_human=True
)
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful assistant"),
MessagesPlaceholder("chat_history", optional=True),
("human", "{input}"),
MessagesPlaceholder("agent_scratchpad"),
]
)
llm_with_tools = llm.bind(functions=tools)

agent: Any = (
{
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_function_messages(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| _TestOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

response = agent_executor.invoke({"input": "What is LangChain?"})
assert isinstance(response, dict)
assert response["input"] == "What is LangChain?"

assert "LangChain" in response["output"]


@pytest.mark.extended
def test_stream() -> None:
from langchain.chains import LLMMathChain
Expand Down

0 comments on commit c91f98c

Please sign in to comment.