Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai[minor],core[minor]: Add support for passing strict in openai tools #6418

Merged
merged 14 commits into from
Aug 6, 2024
78 changes: 76 additions & 2 deletions docs/core_docs/docs/integrations/chat/openai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"id": "bc5ecebd",
"metadata": {},
"source": [
"## Tool calling\n",
Expand All @@ -420,8 +420,82 @@
"\n",
"- [How to: disable parallel tool calling](/docs/how_to/tool_calling_parallel/)\n",
"- [How to: force a tool call](/docs/how_to/tool_choice/)\n",
"- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced).\n",
"- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced)."
]
},
{
"cell_type": "markdown",
"id": "3392390e",
"metadata": {},
"source": [
"### ``strict: true``\n",
"\n",
"```{=mdx}\n",
"\n",
":::info Requires ``@langchain/openai >= 0.2.6``\n",
"\n",
"As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n",
"\n",
"**Note**: If ``strict: true`` the tool definition will also be validated, and a subset of JSON schema are accepted. Crucially, schema cannot have optional args (those with default values). Read the full docs on what types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. \n",
":::\n",
"\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "90f0d465",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\n",
" {\n",
" name: 'get_current_weather',\n",
" args: { location: 'Hanoi' },\n",
" type: 'tool_call',\n",
" id: 'call_aB85ybkLCoccpzqHquuJGH3d'\n",
" }\n",
"]\n"
]
}
],
"source": [
"import { ChatOpenAI } from \"@langchain/openai\";\n",
"import { tool } from \"@langchain/core/tools\";\n",
"import { z } from \"zod\";\n",
"\n",
"const weatherTool = tool((_) => \"no-op\", {\n",
" name: \"get_current_weather\",\n",
" description: \"Get the current weather\",\n",
" schema: z.object({\n",
" location: z.string(),\n",
" }),\n",
"})\n",
"\n",
"const llmWithStrictTrue = new ChatOpenAI({\n",
" model: \"gpt-4o\",\n",
"}).bindTools([weatherTool], {\n",
" strict: true,\n",
" tool_choice: weatherTool.name,\n",
"});\n",
"\n",
"// Although the question is not about the weather, it will call the tool with the correct arguments\n",
"// because we passed `tool_choice` and `strict: true`.\n",
"const strictTrueResult = await llmWithStrictTrue.invoke(\"What is 127862 times 12898 divided by 2?\");\n",
"\n",
"console.dir(strictTrueResult.tool_calls, { depth: null });"
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatOpenAI features and configurations head to the API reference: https://api.js.langchain.com/classes/langchain_openai.ChatOpenAI.html"
Expand Down
9 changes: 9 additions & 0 deletions langchain-core/src/language_models/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,15 @@ export interface FunctionDefinition {
* how to call the function.
*/
description?: string;

/**
* Whether to enable strict schema adherence when generating the function call. If
* set to true, the model will follow the exact schema defined in the `parameters`
* field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn
* more about Structured Outputs in the
* [function calling guide](https://platform.openai.com/docs/guides/function-calling).
*/
strict?: boolean;
}

export interface ToolDefinition {
Expand Down
62 changes: 57 additions & 5 deletions langchain-core/src/utils/function_calling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,26 @@ import { Runnable, RunnableToolLike } from "../runnables/base.js";
* @returns {FunctionDefinition} The inputted tool in OpenAI function format.
*/
export function convertToOpenAIFunction(
tool: StructuredToolInterface | RunnableToolLike
tool: StructuredToolInterface | RunnableToolLike,
fields?:
| {
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the function definition.
*/
strict?: boolean;
}
| number
): FunctionDefinition {
// @TODO 0.3.0 Remove the `number` typing
const fieldsCopy = typeof fields === "number" ? undefined : fields;

return {
name: tool.name,
description: tool.description,
parameters: zodToJsonSchema(tool.schema),
// Do not include the `strict` field if it is `undefined`.
...(fieldsCopy?.strict !== undefined ? { strict: fieldsCopy.strict } : {}),
};
}

Expand All @@ -34,15 +48,35 @@ export function convertToOpenAIFunction(
*/
export function convertToOpenAITool(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike,
fields?:
| {
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the function definition.
*/
strict?: boolean;
}
| number
): ToolDefinition {
if (isStructuredTool(tool) || isRunnableToolLike(tool)) {
return {
// @TODO 0.3.0 Remove the `number` typing
const fieldsCopy = typeof fields === "number" ? undefined : fields;

let toolDef: ToolDefinition | undefined;
if (isLangChainTool(tool)) {
toolDef = {
type: "function",
function: convertToOpenAIFunction(tool),
};
} else {
toolDef = tool as ToolDefinition;
}

if (fieldsCopy?.strict !== undefined) {
toolDef.function.strict = fieldsCopy.strict;
}
return tool as ToolDefinition;

return toolDef;
}

/**
Expand Down Expand Up @@ -76,3 +110,21 @@ export function isRunnableToolLike(tool?: unknown): tool is RunnableToolLike {
tool.constructor.lc_name() === "RunnableToolLike"
);
}

/**
* Whether or not the tool is one of StructuredTool, RunnableTool or StructuredToolParams.
* It returns `is StructuredToolParams` since that is the most minimal interface of the three,
* while still containing the necessary properties to be passed to a LLM for tool calling.
*
* @param {unknown | undefined} tool The tool to check if it is a LangChain tool.
* @returns {tool is StructuredToolParams} Whether the inputted tool is a LangChain tool.
*/
export function isLangChainTool(
tool?: unknown
): tool is StructuredToolInterface {
return (
isRunnableToolLike(tool) ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
isStructuredTool(tool as any)
);
}
4 changes: 3 additions & 1 deletion langchain/src/agents/openai_tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ export async function createOpenAIToolsAgent({
].join("\n")
);
}
const modelWithTools = llm.bind({ tools: tools.map(convertToOpenAITool) });
const modelWithTools = llm.bind({
tools: tools.map((tool) => convertToOpenAITool(tool)),
});
const agent = AgentRunnableSequence.fromRunnables(
[
RunnablePassthrough.assign({
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/openai_tools/output_parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export type { ToolsAgentAction, ToolsAgentStep };
* new ChatOpenAI({
* modelName: "gpt-3.5-turbo-1106",
* temperature: 0,
* }).bind({ tools: tools.map(convertToOpenAITool) }),
* }).bind({ tools: tools.map((tool) => convertToOpenAITool(tool)) }),
* new OpenAIToolsAgentOutputParser(),
* ]).withConfig({ runName: "OpenAIToolsAgent" });
*
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-groq/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ export class ChatGroq extends BaseChatModel<
kwargs?: Partial<ChatGroqCallOptions>
): Runnable<BaseLanguageModelInput, AIMessageChunk, ChatGroqCallOptions> {
return this.bind({
tools: tools.map(convertToOpenAITool),
tools: tools.map((tool) => convertToOpenAITool(tool)),
...kwargs,
});
}
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain-ollama/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ export class ChatOllama
kwargs?: Partial<this["ParsedCallOptions"]>
): Runnable<BaseLanguageModelInput, AIMessageChunk, ChatOllamaCallOptions> {
return this.bind({
tools: tools.map(convertToOpenAITool),
tools: tools.map((tool) => convertToOpenAITool(tool)),
...kwargs,
});
}
Expand Down Expand Up @@ -359,7 +359,9 @@ export class ChatOllama
stop: options?.stop,
},
tools: options?.tools?.length
? (options.tools.map(convertToOpenAITool) as OllamaTool[])
? (options.tools.map((tool) =>
convertToOpenAITool(tool)
) as OllamaTool[])
: undefined,
};
}
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"dependencies": {
"@langchain/core": ">=0.2.16 <0.3.0",
"js-tiktoken": "^1.0.12",
"openai": "^4.49.1",
"openai": "^4.55.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
Expand Down
Loading
Loading