From 27a9fbfa7317bb239eba6d5070b20bbde8fd07fb Mon Sep 17 00:00:00 2001 From: Alex Ostapenko Date: Fri, 12 Apr 2024 10:29:44 +0200 Subject: [PATCH] force tool call on with_structured_output (#140) --- .../langchain_google_vertexai/chat_models.py | 19 +++++++++++-- .../integration_tests/test_chat_models.py | 27 +++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 5a601261..80b5e68c 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -63,6 +63,9 @@ GenerativeModel, Part, ) +from vertexai.generative_models._generative_models import ( # type: ignore + ToolConfig, +) from vertexai.language_models import ( # type: ignore ChatMessage, ChatModel, @@ -860,9 +863,20 @@ class AnswerWithJustification(BaseModel): parser: OutputParserLike = PydanticOutputFunctionsParser( pydantic_schema=schema ) + name = schema.schema()["title"] else: parser = JsonOutputFunctionsParser() - llm = self.bind(functions=[schema]) + name = schema["name"] + + llm = self.bind( + functions=[schema], + tool_config={ + "function_calling_config": { + "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, + "allowed_function_names": [name], + } + }, + ) if include_raw: parser_with_fallback = RunnablePassthrough.assign( parsed=itemgetter("raw") | parser, parsing_error=lambda _: None @@ -877,6 +891,7 @@ class AnswerWithJustification(BaseModel): def bind_tools( self, tools: Sequence[Union[Type[BaseModel], Callable, BaseTool]], + tool_config: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -903,7 +918,7 @@ def bind_tools( raise ValueError( "Tool must be a BaseTool, Pydantic model, or callable." ) - return super().bind(functions=formatted_tools, **kwargs) + return self.bind(functions=formatted_tools, tool_config=tool_config, **kwargs) def _start_chat( self, history: _ChatHistory, **kwargs: Any diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py index 95a07752..91627cc9 100644 --- a/libs/vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/vertexai/tests/integration_tests/test_chat_models.py @@ -408,3 +408,30 @@ class MyModel(BaseModel): assert response.content != "" function_call = response.additional_kwargs.get("function_call") assert function_call is None + + +@pytest.mark.release +def test_chat_vertexai_gemini_function_calling_with_structured_output() -> None: + class MyModel(BaseModel): + name: str + age: int + + safety = { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + } + llm = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409", safety_settings=safety) + model = llm.with_structured_output(MyModel) + message = HumanMessage(content="My name is Erick and I am 27 years old") + + response = model.invoke([message]) + assert isinstance(response, MyModel) + assert response == MyModel(name="Erick", age=27) + + model = llm.with_structured_output( + {"name": "MyModel", "description": "MyModel", "parameters": MyModel.schema()} + ) + response = model.invoke([message]) + assert response == { + "name": "Erick", + "age": 27, + }