Skip to content

Commit

Permalink
openai, anthropic, ...: udpate with_structured_output to pass in expl…
Browse files Browse the repository at this point in the history
…icit tool choice
  • Loading branch information
baskaryan committed Jun 28, 2024
1 parent 8fce8c6 commit af5255e
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
4 changes: 3 additions & 1 deletion libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,9 @@ class AnswerWithJustification(BaseModel):
# }
""" # noqa: E501
llm = self.bind_tools([schema], tool_choice="any")

tool_name = convert_to_anthropic_tool(schema)["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser = ToolsOutputParser(
first_tool_only=True, pydantic_schemas=[schema]
Expand Down
6 changes: 3 additions & 3 deletions libs/partners/fireworks/langchain_fireworks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,15 +884,15 @@ class AnswerWithJustification(BaseModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
Expand Down
6 changes: 3 additions & 3 deletions libs/partners/groq/langchain_groq/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,15 +1014,15 @@ class AnswerWithJustification(BaseModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
Expand Down
6 changes: 3 additions & 3 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,15 +794,15 @@ class AnswerWithJustification(BaseModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice="any")
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
Expand Down
8 changes: 5 additions & 3 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,15 +1140,17 @@ class AnswerWithJustification(BaseModel):
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True, parallel_tool_calls=False)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools(
[schema], tool_choice=tool_name, parallel_tool_calls=False
)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
Expand Down

0 comments on commit af5255e

Please sign in to comment.