Skip to content

Commit

Permalink
force tool call on with_structured_output (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 authored Apr 12, 2024
1 parent 62ba6ca commit 27a9fbf
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
19 changes: 17 additions & 2 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
GenerativeModel,
Part,
)
from vertexai.generative_models._generative_models import ( # type: ignore
ToolConfig,
)
from vertexai.language_models import ( # type: ignore
ChatMessage,
ChatModel,
Expand Down Expand Up @@ -860,9 +863,20 @@ class AnswerWithJustification(BaseModel):
parser: OutputParserLike = PydanticOutputFunctionsParser(
pydantic_schema=schema
)
name = schema.schema()["title"]
else:
parser = JsonOutputFunctionsParser()
llm = self.bind(functions=[schema])
name = schema["name"]

llm = self.bind(
functions=[schema],
tool_config={
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": [name],
}
},
)
if include_raw:
parser_with_fallback = RunnablePassthrough.assign(
parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
Expand All @@ -877,6 +891,7 @@ class AnswerWithJustification(BaseModel):
def bind_tools(
self,
tools: Sequence[Union[Type[BaseModel], Callable, BaseTool]],
tool_config: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Expand All @@ -903,7 +918,7 @@ def bind_tools(
raise ValueError(
"Tool must be a BaseTool, Pydantic model, or callable."
)
return super().bind(functions=formatted_tools, **kwargs)
return self.bind(functions=formatted_tools, tool_config=tool_config, **kwargs)

def _start_chat(
self, history: _ChatHistory, **kwargs: Any
Expand Down
27 changes: 27 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,30 @@ class MyModel(BaseModel):
assert response.content != ""
function_call = response.additional_kwargs.get("function_call")
assert function_call is None


@pytest.mark.release
def test_chat_vertexai_gemini_function_calling_with_structured_output() -> None:
class MyModel(BaseModel):
name: str
age: int

safety = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
}
llm = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409", safety_settings=safety)
model = llm.with_structured_output(MyModel)
message = HumanMessage(content="My name is Erick and I am 27 years old")

response = model.invoke([message])
assert isinstance(response, MyModel)
assert response == MyModel(name="Erick", age=27)

model = llm.with_structured_output(
{"name": "MyModel", "description": "MyModel", "parameters": MyModel.schema()}
)
response = model.invoke([message])
assert response == {
"name": "Erick",
"age": 27,
}

0 comments on commit 27a9fbf

Please sign in to comment.