Skip to content

Commit

Permalink
fixed the way how function definitions is constructed
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 committed Mar 25, 2024
1 parent d7bc26a commit 51ee061
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 24 deletions.
49 changes: 26 additions & 23 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import FunctionDescription
from langchain_core.utils.json_schema import dereference_refs
Expand All @@ -19,8 +19,7 @@
def _format_pydantic_to_vertex_function(
pydantic_model: Type[BaseModel],
) -> FunctionDescription:
schema = dereference_refs(pydantic_model.schema())
schema.pop("definitions", None)
schema = pydantic_model.schema()

return {
"name": schema["title"],
Expand All @@ -32,8 +31,7 @@ def _format_pydantic_to_vertex_function(
def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
"Format tool into the Vertex function API."
if tool.args_schema:
schema = dereference_refs(tool.args_schema.schema())
schema.pop("definitions", None)
schema = tool.args_schema.schema()

return {
"name": tool.name or schema["title"],
Expand Down Expand Up @@ -77,6 +75,25 @@ def _format_tools_to_vertex_tool(
return [VertexTool(function_declarations=function_declarations)]


class ParametersSchema(BaseModel):
"""
This is a schema of currently supported definitions in function calling.
We need explicitly exclude `title` and `definitions` fields as they
are not currently supported.
All other fields will be passed through (as extra fields are allowed)
and intercepted on `google.cloud.aiplatform` level
"""

title: Optional[str] = Field(exclude=True)
definitions: Optional[Any] = Field(exclude=True)
items: Optional["ParametersSchema"]
properties: Optional[Dict[str, "ParametersSchema"]]

class Config:
extra = "allow"


def _get_parameters_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Given a schema, format the parameters key to match VertexAI
expected input.
Expand All @@ -88,24 +105,10 @@ def _get_parameters_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
Dictionary with the formatted parameters.
"""

parameters = {}

parameters["type"] = schema["type"]

if "required" in schema:
parameters["required"] = schema["required"]

schema_properties: Dict[str, Any] = schema.get("properties", {})

parameters["properties"] = {
parameter_name: {
"type": parameter_dict["type"],
"description": parameter_dict.get("description"),
}
for parameter_name, parameter_dict in schema_properties.items()
}
dereferenced_schema = dereference_refs(schema)
model = ParametersSchema.parse_obj(dereferenced_schema)

return parameters
return model.dict(exclude_unset=True)


class PydanticFunctionsOutputParser(BaseOutputParser):
Expand Down
43 changes: 43 additions & 0 deletions libs/vertexai/tests/integration_tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import re
from typing import Any, List, Union
Expand All @@ -8,6 +9,7 @@
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import Tool

from langchain_google_vertexai.chat_models import ChatVertexAI
Expand Down Expand Up @@ -140,6 +142,47 @@ def search(query: str) -> str:
assert "LangChain" in response["output"]


@pytest.mark.extended
def test_tool_nested_properties() -> None:
from langchain.agents import tool

class Movie(BaseModel):
actor: str = Field(description="Actor in the film")
director: str = Field(description="Director of the film")

class Input(BaseModel):
movie: Movie = Field(description="Movie parameters object")

@tool("movie_search", return_direct=True)
def movie_search(input: Input) -> str:
"""Return last movie title by actor and director."""
return "Pulp Fiction"

tools = [movie_search]

llm = ChatVertexAI(
model_name="gemini-pro", temperature=0.0, convert_system_message_to_human=True
)
llm_with_tools = llm.bind(functions=tools)

response = llm_with_tools.invoke(
"What was the last movie directed by Quentin Tarantino with Bruce Willis?"
)

assert "function_call" in response.additional_kwargs
function_call = response.additional_kwargs["function_call"]

assert function_call["name"] == "movie_search"

assert "arguments" in function_call

arguments = json.loads(function_call["arguments"])
assert "input" in arguments
assert "movie" in arguments["input"]
assert "actor" in arguments["input"]["movie"]
assert "director" in arguments["input"]["movie"]


@pytest.mark.extended
def test_stream() -> None:
from langchain.chains import LLMMathChain
Expand Down
36 changes: 35 additions & 1 deletion libs/vertexai/tests/unit_tests/test_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import Optional, Union

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool

from langchain_google_vertexai.functions_utils import _format_tool_to_vertex_function
from langchain_google_vertexai.functions_utils import (
_format_tool_to_vertex_function,
_get_parameters_from_schema,
)


def test_format_tool_to_vertex_function():
Expand Down Expand Up @@ -43,3 +49,31 @@ def do_something_optional(a: float, b: float = 0) -> str:
assert schema["name"] == "do_something_optional"
assert "parameters" in schema
assert len(schema["parameters"]["required"]) == 1


def test_get_parameters_from_schema():
class A(BaseModel):
a1: Optional[int]

class B(BaseModel):
b1: Optional[A]
b2: int = Field(description="f2")
b3: Union[int, str]

schema = B.schema()
result = _get_parameters_from_schema(schema)
assert result["type"] == "object"
assert "required" in result
assert len(result["required"]) == 2

assert "properties" in result
assert "b1" in result["properties"]
assert "b2" in result["properties"]
assert "b3" in result["properties"]

assert result["properties"]["b1"]["type"] == "object"
assert "a1" in result["properties"]["b1"]["properties"]
assert "required" not in result["properties"]["b1"]
assert len(result["properties"]["b1"]["properties"]) == 1

assert "anyOf" in result["properties"]["b3"]

0 comments on commit 51ee061

Please sign in to comment.