diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 3f60067a..8a75a584 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -6,7 +6,20 @@ import logging from dataclasses import dataclass, field from operator import itemgetter -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union, cast +import uuid +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Type, + Union, + cast, +) import proto # type: ignore[import-untyped] from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart @@ -34,6 +47,7 @@ ToolCallChunk, ToolMessage, ) +from langchain_core.tools import BaseTool, tool as tool_from_callable from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_functions import ( JsonOutputFunctionsParser, @@ -206,9 +220,17 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: ] elif isinstance(message, ToolMessage): role = "function" + if (i > 0) and isinstance(history[i - 1], AIMessage): + # message.name can be null for ToolMessage + if history[i - 1].tool_calls: # type: ignore + name = history[i - 1].tool_calls[0]["name"] # type: ignore + else: + name = message.name + else: + name = message.name parts = [ Part.from_function_response( - name=message.name, + name=name, response={ "content": message.content, }, @@ -313,7 +335,7 @@ def _parse_response_candidate( ToolCallChunk( name=function_call.get("name"), args=function_call.get("arguments"), - id=function_call.get("id"), + id=function_call.get("id", str(uuid.uuid4())), index=function_call.get("index"), ) ] @@ -334,7 +356,7 @@ def _parse_response_candidate( ToolCall( name=tool_call["name"], args=tool_call["args"], - id=tool_call.get("id"), + id=tool_call.get("id", str(uuid.uuid4())), ) for tool_call in tool_calls_dicts ] @@ -343,7 +365,7 @@ def _parse_response_candidate( InvalidToolCall( name=function_call.get("name"), args=function_call.get("arguments"), - id=function_call.get("id"), + id=function_call.get("id", str(uuid.uuid4())), error=str(e), ) ] @@ -844,6 +866,37 @@ class AnswerWithJustification(BaseModel): else: return llm | parser + def bind_tools( + self, + tools: Sequence[Union[Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Assumes model is compatible with Vertex tool-calling API. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + formatted_tools = [] + for schema in tools: + if isinstance(schema, BaseTool) or ( + isinstance(schema, type) and issubclass(schema, BaseModel) + ): + formatted_tools.append(schema) + elif callable(schema): + formatted_tools.append(tool_from_callable(schema)) # type: ignore + else: + raise ValueError( + "Tool must be a BaseTool, Pydantic model, or callable." + ) + return super().bind(functions=formatted_tools, **kwargs) + def _start_chat( self, history: _ChatHistory, **kwargs: Any ) -> Union[ChatSession, CodeChatSession]: diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py index f60acfe7..3a30d56d 100644 --- a/libs/vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/vertexai/tests/integration_tests/test_chat_models.py @@ -7,13 +7,13 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, + BaseMessage, HumanMessage, SystemMessage, - ToolCall, - ToolCallChunk, ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.pydantic_v1 import BaseModel +from langchain_core.tools import tool from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory @@ -262,26 +262,14 @@ def test_get_num_tokens_from_messages(model_name: str) -> None: assert token == 3 -@pytest.mark.release -def test_chat_vertexai_gemini_function_calling() -> None: - class MyModel(BaseModel): - name: str - age: int - - safety = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH - } - model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind( - functions=[MyModel] - ) - message = HumanMessage(content="My name is Erick and I am 27 years old") - response = model.invoke([message]) +def _check_tool_calls(response: BaseMessage, expected_name: str) -> None: + """Check tool calls are as expected.""" assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert response.content == "" function_call = response.additional_kwargs.get("function_call") assert function_call - assert function_call["name"] == "MyModel" + assert function_call["name"] == expected_name arguments_str = function_call.get("arguments") assert arguments_str arguments = json.loads(arguments_str) @@ -289,10 +277,54 @@ class MyModel(BaseModel): "name": "Erick", "age": 27.0, } - assert response.tool_calls == [ - ToolCall(name="MyModel", args={"age": 27.0, "name": "Erick"}, id=None) - ] + tool_calls = response.tool_calls + assert len(tool_calls) == 1 + tool_call = tool_calls[0] + assert tool_call["name"] == expected_name + assert tool_call["args"] == {"age": 27.0, "name": "Erick"} + +@pytest.mark.release +def test_chat_vertexai_gemini_function_calling() -> None: + class MyModel(BaseModel): + name: str + age: int + + safety = { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + } + # Test .bind_tools with BaseModel + message = HumanMessage(content="My name is Erick and I am 27 years old") + model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind_tools( + [MyModel] + ) + response = model.invoke([message]) + _check_tool_calls(response, "MyModel") + + # Test .bind_tools with function + def my_model(name: str, age: int) -> None: + """Invoke this with names and ages.""" + pass + + model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind_tools( + [my_model] + ) + response = model.invoke([message]) + _check_tool_calls(response, "my_model") + + # Test .bind_tools with tool + @tool + def my_tool(name: str, age: int) -> None: + """Invoke this with names and ages.""" + pass + + model = ChatVertexAI(model_name="gemini-pro", safety_settings=safety).bind_tools( + [my_tool] + ) + response = model.invoke([message]) + _check_tool_calls(response, "my_tool") + + # Test streaming stream = model.stream([message]) first = True for chunk in stream: @@ -302,8 +334,7 @@ class MyModel(BaseModel): else: gathered = gathered + chunk # type: ignore assert isinstance(gathered, AIMessageChunk) - assert gathered.tool_call_chunks == [ - ToolCallChunk( - name="MyModel", args='{"age": 27.0, "name": "Erick"}', id=None, index=None - ) - ] + assert len(gathered.tool_call_chunks) == 1 + tool_call_chunk = gathered.tool_call_chunks[0] + assert tool_call_chunk["name"] == "my_tool" + assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'