diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py index f2981a0801812..94c7ea6a55b79 100644 --- a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py @@ -1,4 +1,6 @@ """Test ChatGoogleVertexAI chat model.""" + +import json from typing import Optional, cast import pytest @@ -9,6 +11,7 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult +from langchain_core.pydantic_v1 import BaseModel from langchain_google_vertexai.chat_models import ChatVertexAI @@ -220,3 +223,26 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None: response = model([system_message, message1, message2, message3]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) + + +def test_chat_vertexai_gemini_function_calling() -> None: + class MyModel(BaseModel): + name: str + age: int + + model = ChatVertexAI(model_name="gemini-pro").bind(functions=[MyModel]) + message = HumanMessage(content="My name is Erick and I am 27 years old") + response = model.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert response.content == "" + function_call = response.additional_kwargs.get("function_call") + assert function_call + assert function_call["name"] == "MyModel" + arguments_str = function_call.get("arguments") + assert arguments_str + arguments = json.loads(arguments_str) + assert arguments == { + "name": "Erick", + "age": 27.0, + }