diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 2bf56608..29a103c0 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -62,8 +62,14 @@ def _format_tools_to_vertex_tool( for tool in tools: if isinstance(tool, BaseTool): func = _format_tool_to_vertex_function(tool) - else: + elif isinstance(tool, type) and issubclass(tool, BaseModel): func = _format_pydantic_to_vertex_function(tool) + else: + func = { + "name": tool["name"], + "description": tool.get("description"), + "parameters": _get_parameters_from_schema(tool["parameters"]), + } function_declarations.append(FunctionDeclaration(**func)) return [VertexTool(function_declarations=function_declarations)] diff --git a/libs/vertexai/tests/integration_tests/test_tools.py b/libs/vertexai/tests/integration_tests/test_tools.py index 9d378d6e..16750164 100644 --- a/libs/vertexai/tests/integration_tests/test_tools.py +++ b/libs/vertexai/tests/integration_tests/test_tools.py @@ -93,6 +93,38 @@ def test_tools() -> None: assert round(float(just_numbers), 2) == 2.16 +@pytest.mark.extended +def test_custom_tool() -> None: + from langchain.agents import AgentExecutor, create_openai_functions_agent + from langchain.agents import tool + + @tool + def search(query: str) -> str: + """Look up things online.""" + return "LangChain" + + tools = [search] + + llm = ChatVertexAI(model_name="gemini-pro", temperature=0.0, convert_system_message_to_human=True) + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful assistant"), + MessagesPlaceholder("chat_history", optional=True), + ("human", "{input}"), + MessagesPlaceholder("agent_scratchpad"), + ] + ) + + agent = create_openai_functions_agent(llm, tools, prompt) + agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) + + response = agent_executor.invoke({"input": "What is LangChain?"}) + assert isinstance(response, dict) + assert response["input"] == "What is LangChain?" + + assert "LangChain" in response["output"] + + @pytest.mark.extended def test_stream() -> None: from langchain.chains import LLMMathChain