From 49d85fdb114e9a91af5809675c5e206c8ca358e8 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 12:40:51 -0700 Subject: [PATCH 01/13] openai[minor],core[minor]: Add support for passing strict in openai tools --- langchain-core/src/language_models/base.ts | 9 + langchain-core/src/utils/function_calling.ts | 19 +- libs/langchain-openai/package.json | 2 +- libs/langchain-openai/src/chat_models.ts | 37 ++- .../src/tests/chat_models.test.ts | 212 ++++++++++++++++++ .../chat_models_structured_output.int.test.ts | 1 + libs/langchain-openai/src/types.ts | 7 + yarn.lock | 24 +- 8 files changed, 299 insertions(+), 12 deletions(-) create mode 100644 libs/langchain-openai/src/tests/chat_models.test.ts diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts index 0e8af1bc32bf..cea8ca2f9ae3 100644 --- a/langchain-core/src/language_models/base.ts +++ b/langchain-core/src/language_models/base.ts @@ -233,6 +233,15 @@ export interface FunctionDefinition { * how to call the function. */ description?: string; + + /** + * Whether to enable strict schema adherence when generating the function call. If + * set to true, the model will follow the exact schema defined in the `parameters` + * field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn + * more about Structured Outputs in the + * [function calling guide](https://platform.openai.com/docs/guides/function-calling). + */ + strict?: boolean; } export interface ToolDefinition { diff --git a/langchain-core/src/utils/function_calling.ts b/langchain-core/src/utils/function_calling.ts index 3871ffc4453d..6155f3cd8f8a 100644 --- a/langchain-core/src/utils/function_calling.ts +++ b/langchain-core/src/utils/function_calling.ts @@ -34,14 +34,29 @@ export function convertToOpenAIFunction( */ export function convertToOpenAITool( // eslint-disable-next-line @typescript-eslint/no-explicit-any - tool: StructuredToolInterface | Record | RunnableToolLike + tool: StructuredToolInterface | Record | RunnableToolLike, + fields?: { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. 
+ */ + strict?: boolean; + } ): ToolDefinition { + let toolDef: ToolDefinition | undefined; if (isStructuredTool(tool) || isRunnableToolLike(tool)) { - return { + toolDef = { type: "function", function: convertToOpenAIFunction(tool), }; + } else { + toolDef = tool as ToolDefinition; + } + + if (fields?.strict !== undefined) { + toolDef.function.strict = fields.strict; } + return tool as ToolDefinition; } diff --git a/libs/langchain-openai/package.json b/libs/langchain-openai/package.json index 3115ef248c48..7a565308410c 100644 --- a/libs/langchain-openai/package.json +++ b/libs/langchain-openai/package.json @@ -37,7 +37,7 @@ "dependencies": { "@langchain/core": ">=0.2.16 <0.3.0", "js-tiktoken": "^1.0.12", - "openai": "^4.49.1", + "openai": "^4.55.0", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.3" }, diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index db86e6e91940..af06da347149 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -1,4 +1,4 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; +import { type ClientOptions, OpenAI as OpenAIClient, } from "openai"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { @@ -299,6 +299,16 @@ export interface ChatOpenAICallOptions * call multiple tools in one response. */ parallel_tool_calls?: boolean; + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the tool definition. + * Enabled by default for `"gpt-"` models. + */ + strict?: boolean; +} + +export interface ChatOpenAIFields extends Partial, Partial, BaseChatModelParams { + configuration?: ClientOptions & LegacyOpenAIInput; } /** @@ -441,12 +451,15 @@ export class ChatOpenAI< protected clientConfig: ClientOptions; + /** + * Whether the model supports the 'strict' argument when passing in tools. + * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise + * defaults to `false`. + */ + supportsStrictToolCalling?: boolean; + constructor( - fields?: Partial & - Partial & - BaseChatModelParams & { - configuration?: ClientOptions & LegacyOpenAIInput; - }, + fields?: ChatOpenAIFields, /** @deprecated */ configuration?: ClientOptions & LegacyOpenAIInput ) { @@ -541,6 +554,12 @@ export class ChatOpenAI< ...configuration, ...fields?.configuration, }; + + // Assume only "gpt-..." models support strict tool calling as of 08/06/24. + this.supportsStrictToolCalling = + fields?.supportsStrictToolCalling !== undefined + ? fields.supportsStrictToolCalling + : this.modelName.startsWith("gpt-"); } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -563,8 +582,9 @@ export class ChatOpenAI< )[], kwargs?: Partial ): Runnable { + const strict = kwargs?.strict !== undefined ? kwargs.strict : this.supportsStrictToolCalling; return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => convertToOpenAITool(tool, { strict })), ...kwargs, } as Partial); } @@ -578,6 +598,7 @@ export class ChatOpenAI< streaming?: boolean; } ): Omit { + const strict = options?.strict !== undefined ? options.strict : this.supportsStrictToolCalling; function isStructuredToolArray( tools?: unknown[] ): tools is StructuredToolInterface[] { @@ -615,7 +636,7 @@ export class ChatOpenAI< functions: options?.functions, function_call: options?.function_call, tools: isStructuredToolArray(options?.tools) - ? options?.tools.map(convertToOpenAITool) + ? 
options?.tools.map((tool) => convertToOpenAITool(tool, { strict })) : options?.tools, tool_choice: formatToOpenAIToolChoice(options?.tool_choice), response_format: options?.response_format, diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts new file mode 100644 index 000000000000..3624ea148c77 --- /dev/null +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -0,0 +1,212 @@ +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; +import { ChatOpenAI } from "../chat_models.js"; + + +describe("strict tool calling", () => { + const weatherTool = { + type: "function" as const, + function: { + name: "get_current_weather", + description: "Get the current weather in a location", + parameters: zodToJsonSchema(z.object({ + location: z.string().describe("The location to get the weather for"), + })) + } + } + + // Store the original value of LANGCHAIN_TRACING_V2 + let oldLangChainTracingValue: string | undefined; + // Before all tests, save the current LANGCHAIN_TRACING_V2 value + beforeAll(() => { + oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2; + }) + // After all tests, restore the original LANGCHAIN_TRACING_V2 value + afterAll(() => { + if (oldLangChainTracingValue !== undefined) { + process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue; + } else { + // If it was undefined, remove the environment variable + delete process.env.LANGCHAIN_TRACING_V2; + } + }) + + it("Can accept strict as a call arg via .bindTools", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // Store the request details for later inspection + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bindTools([weatherTool], { strict: true }); + + // This will fail since we're not returning a valid response in our mocked fetch function. 
+ await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bindTools` + strict: true, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); + + it("Can accept strict as a call arg via .bind", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // Store the request details for later inspection + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bind({ + tools: [weatherTool], + strict: true + }); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: true, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); + + it("Sets strict to true if the model name starts with 'gpt-'", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // Store the request details for later inspection + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + // Do NOT pass `strict` here since we're checking that it's set to true by default + const modelWithTools = model.bindTools([weatherTool]); + + // This will fail since we're not returning a valid response in our mocked fetch function. 
+ await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: true, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); + + it("Strict is false if supportsStrictToolCalling is false", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // Store the request details for later inspection + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + supportsStrictToolCalling: false, + }); + + // Do NOT pass `strict` here since we're checking that it's set to true by default + const modelWithTools = model.bindTools([weatherTool]); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: false, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); +}) diff --git a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts index 86bf0247bd49..95f379c8696b 100644 --- a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts @@ -3,6 +3,7 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { AIMessage } from "@langchain/core/messages"; import { ChatOpenAI } from "../chat_models.js"; +import { test, expect } from "@jest/globals"; test("withStructuredOutput zod schema function calling", async () => { const model = new ChatOpenAI({ diff --git a/libs/langchain-openai/src/types.ts b/libs/langchain-openai/src/types.ts index 19e6af483d7d..afd9fea2624b 100644 --- a/libs/langchain-openai/src/types.ts +++ b/libs/langchain-openai/src/types.ts @@ -155,6 +155,13 @@ export interface OpenAIChatInput extends OpenAIBaseInput { * Currently in experimental beta. */ __includeRawResponse?: boolean; + + /** + * Whether the model supports the 'strict' argument when passing in tools. + * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise + * defaults to `false`. 
+ */ + supportsStrictToolCalling?: boolean; } export declare interface AzureOpenAIInput { diff --git a/yarn.lock b/yarn.lock index c96e97605695..075329f22ffb 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12199,7 +12199,7 @@ __metadata: jest: ^29.5.0 jest-environment-node: ^29.6.4 js-tiktoken: ^1.0.12 - openai: ^4.49.1 + openai: ^4.55.0 prettier: ^2.8.3 release-it: ^17.6.0 rimraf: ^5.0.1 @@ -34040,6 +34040,28 @@ __metadata: languageName: node linkType: hard +"openai@npm:^4.55.0": + version: 4.55.0 + resolution: "openai@npm:4.55.0" + dependencies: + "@types/node": ^18.11.18 + "@types/node-fetch": ^2.6.4 + abort-controller: ^3.0.0 + agentkeepalive: ^4.2.1 + form-data-encoder: 1.7.2 + formdata-node: ^4.3.2 + node-fetch: ^2.6.7 + peerDependencies: + zod: ^3.23.8 + peerDependenciesMeta: + zod: + optional: true + bin: + openai: bin/cli + checksum: b2b1daa976516262e08e182ee982976a1dc615eebd250bbd71f4122740ebeeb207a20af6d35c718b67f1c3457196b524667a0c7fa417ab4e119020b5c1f5cd74 + languageName: node + linkType: hard + "openapi-types@npm:^12.1.3": version: 12.1.3 resolution: "openapi-types@npm:12.1.3" From 9808646048ded99bf508ceb6f975e4e5ad45b5d6 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 12:45:06 -0700 Subject: [PATCH 02/13] add integration test --- .../chat_models_structured_output.int.test.ts | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts index 95f379c8696b..5d4d4a13af85 100644 --- a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts @@ -321,3 +321,30 @@ test("parallelToolCalls param", async () => { // console.log(response.tool_calls); expect(response.tool_calls?.length).toBe(1); }); + +test("Passing strict true forces the model to conform to the schema", async () => { + const model = new ChatOpenAI({ + model: "gpt-4o", + temperature: 0, + maxRetries: 0, + }); + + const weatherTool = { + type: "function" as const, + function: { + name: "get_current_weather", + description: "Get the current weather in a location", + parameters: zodToJsonSchema(z.object({ + location: z.string().describe("The location to get the weather for"), + })) + } + } + const modelWithTools = model.bindTools([weatherTool], { strict: true, tool_choice: "get_current_weather" }); + + const result = await modelWithTools.invoke("Whats the result of 173827 times 287326 divided by 2?"); + // Expect at least one tool call, allow multiple + expect(result.tool_calls?.length).toBeGreaterThanOrEqual(1); + expect(result.tool_calls?.[0].name).toBe("get_current_weather"); + expect(result.tool_calls?.[0].args).toHaveProperty("location"); + console.log(result.tool_calls?.[0].args) +}) \ No newline at end of file From 9c94f4e2f3ed6aa3fcf112c8c41baddc5e2a5027 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 12:45:52 -0700 Subject: [PATCH 03/13] chore: lint files --- libs/langchain-openai/src/chat_models.ts | 25 ++-- .../src/tests/chat_models.test.ts | 133 ++++++++++-------- .../chat_models_structured_output.int.test.ts | 27 ++-- 3 files changed, 109 insertions(+), 76 deletions(-) diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index af06da347149..3c9d43e4b540 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -1,4 +1,4 @@ -import { type 
ClientOptions, OpenAI as OpenAIClient, } from "openai"; +import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { @@ -307,7 +307,10 @@ export interface ChatOpenAICallOptions strict?: boolean; } -export interface ChatOpenAIFields extends Partial, Partial, BaseChatModelParams { +export interface ChatOpenAIFields + extends Partial, + Partial, + BaseChatModelParams { configuration?: ClientOptions & LegacyOpenAIInput; } @@ -556,10 +559,10 @@ export class ChatOpenAI< }; // Assume only "gpt-..." models support strict tool calling as of 08/06/24. - this.supportsStrictToolCalling = - fields?.supportsStrictToolCalling !== undefined - ? fields.supportsStrictToolCalling - : this.modelName.startsWith("gpt-"); + this.supportsStrictToolCalling = + fields?.supportsStrictToolCalling !== undefined + ? fields.supportsStrictToolCalling + : this.modelName.startsWith("gpt-"); } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -582,7 +585,10 @@ export class ChatOpenAI< )[], kwargs?: Partial ): Runnable { - const strict = kwargs?.strict !== undefined ? kwargs.strict : this.supportsStrictToolCalling; + const strict = + kwargs?.strict !== undefined + ? kwargs.strict + : this.supportsStrictToolCalling; return this.bind({ tools: tools.map((tool) => convertToOpenAITool(tool, { strict })), ...kwargs, @@ -598,7 +604,10 @@ export class ChatOpenAI< streaming?: boolean; } ): Omit { - const strict = options?.strict !== undefined ? options.strict : this.supportsStrictToolCalling; + const strict = + options?.strict !== undefined + ? options.strict + : this.supportsStrictToolCalling; function isStructuredToolArray( tools?: unknown[] ): tools is StructuredToolInterface[] { diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts index 3624ea148c77..a51626a00c0e 100644 --- a/libs/langchain-openai/src/tests/chat_models.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -3,25 +3,26 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; import { ChatOpenAI } from "../chat_models.js"; - describe("strict tool calling", () => { const weatherTool = { type: "function" as const, function: { name: "get_current_weather", description: "Get the current weather in a location", - parameters: zodToJsonSchema(z.object({ - location: z.string().describe("The location to get the weather for"), - })) - } - } + parameters: zodToJsonSchema( + z.object({ + location: z.string().describe("The location to get the weather for"), + }) + ), + }, + }; // Store the original value of LANGCHAIN_TRACING_V2 let oldLangChainTracingValue: string | undefined; // Before all tests, save the current LANGCHAIN_TRACING_V2 value beforeAll(() => { oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2; - }) + }); // After all tests, restore the original LANGCHAIN_TRACING_V2 value afterAll(() => { if (oldLangChainTracingValue !== undefined) { @@ -30,14 +31,14 @@ describe("strict tool calling", () => { // If it was undefined, remove the environment variable delete process.env.LANGCHAIN_TRACING_V2; } - }) + }); it("Can accept strict as a call arg via .bindTools", async () => { const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); mockFetch.mockImplementation((url, options): Promise => { // Store the request details for later inspection mockFetch.mock.calls.push({ url, options } as 
any); - + // Return a mock response return Promise.resolve({ ok: true, @@ -56,22 +57,26 @@ describe("strict tool calling", () => { const modelWithTools = model.bindTools([weatherTool], { strict: true }); // This will fail since we're not returning a valid response in our mocked fetch function. - await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); expect(mockFetch).toHaveBeenCalled(); const [_url, options] = mockFetch.mock.calls[0]; - + if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bindTools` - strict: true, - } - })]); + expect(JSON.parse(options.body).tools).toEqual([ + expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bindTools` + strict: true, + }, + }), + ]); } else { - throw new Error("Body not found in request.") + throw new Error("Body not found in request."); } }); @@ -80,7 +85,7 @@ describe("strict tool calling", () => { mockFetch.mockImplementation((url, options): Promise => { // Store the request details for later inspection mockFetch.mock.calls.push({ url, options } as any); - + // Return a mock response return Promise.resolve({ ok: true, @@ -98,26 +103,30 @@ describe("strict tool calling", () => { const modelWithTools = model.bind({ tools: [weatherTool], - strict: true + strict: true, }); // This will fail since we're not returning a valid response in our mocked fetch function. - await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); expect(mockFetch).toHaveBeenCalled(); const [_url, options] = mockFetch.mock.calls[0]; - + if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bind` - strict: true, - } - })]); + expect(JSON.parse(options.body).tools).toEqual([ + expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: true, + }, + }), + ]); } else { - throw new Error("Body not found in request.") + throw new Error("Body not found in request."); } }); @@ -126,7 +135,7 @@ describe("strict tool calling", () => { mockFetch.mockImplementation((url, options): Promise => { // Store the request details for later inspection mockFetch.mock.calls.push({ url, options } as any); - + // Return a mock response return Promise.resolve({ ok: true, @@ -146,22 +155,26 @@ describe("strict tool calling", () => { const modelWithTools = model.bindTools([weatherTool]); // This will fail since we're not returning a valid response in our mocked fetch function. 
- await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); expect(mockFetch).toHaveBeenCalled(); const [_url, options] = mockFetch.mock.calls[0]; - + if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bind` - strict: true, - } - })]); + expect(JSON.parse(options.body).tools).toEqual([ + expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: true, + }, + }), + ]); } else { - throw new Error("Body not found in request.") + throw new Error("Body not found in request."); } }); @@ -170,7 +183,7 @@ describe("strict tool calling", () => { mockFetch.mockImplementation((url, options): Promise => { // Store the request details for later inspection mockFetch.mock.calls.push({ url, options } as any); - + // Return a mock response return Promise.resolve({ ok: true, @@ -191,22 +204,26 @@ describe("strict tool calling", () => { const modelWithTools = model.bindTools([weatherTool]); // This will fail since we're not returning a valid response in our mocked fetch function. - await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); expect(mockFetch).toHaveBeenCalled(); const [_url, options] = mockFetch.mock.calls[0]; - + if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bind` - strict: false, - } - })]); + expect(JSON.parse(options.body).tools).toEqual([ + expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: false, + }, + }), + ]); } else { - throw new Error("Body not found in request.") + throw new Error("Body not found in request."); } }); -}) +}); diff --git a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts index 5d4d4a13af85..bc0328357a73 100644 --- a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts @@ -2,8 +2,8 @@ import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { AIMessage } from "@langchain/core/messages"; -import { ChatOpenAI } from "../chat_models.js"; import { test, expect } from "@jest/globals"; +import { ChatOpenAI } from "../chat_models.js"; test("withStructuredOutput zod schema function calling", async () => { const model = new ChatOpenAI({ @@ -334,17 +334,24 @@ test("Passing strict true forces the model to conform to the schema", async () = function: { name: "get_current_weather", description: "Get the current weather in a location", - parameters: zodToJsonSchema(z.object({ - location: z.string().describe("The location to get the weather for"), - })) - } - } - const modelWithTools = model.bindTools([weatherTool], { strict: 
true, tool_choice: "get_current_weather" }); + parameters: zodToJsonSchema( + z.object({ + location: z.string().describe("The location to get the weather for"), + }) + ), + }, + }; + const modelWithTools = model.bindTools([weatherTool], { + strict: true, + tool_choice: "get_current_weather", + }); - const result = await modelWithTools.invoke("Whats the result of 173827 times 287326 divided by 2?"); + const result = await modelWithTools.invoke( + "Whats the result of 173827 times 287326 divided by 2?" + ); // Expect at least one tool call, allow multiple expect(result.tool_calls?.length).toBeGreaterThanOrEqual(1); expect(result.tool_calls?.[0].name).toBe("get_current_weather"); expect(result.tool_calls?.[0].args).toHaveProperty("location"); - console.log(result.tool_calls?.[0].args) -}) \ No newline at end of file + console.log(result.tool_calls?.[0].args); +}); From 82e0f94c27c2fabe2fb71a7c94c15fb59dbab25a Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 12:52:10 -0700 Subject: [PATCH 04/13] Cr --- libs/langchain-openai/src/tests/chat_models.test.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts index a51626a00c0e..6ae9dd4bc00b 100644 --- a/libs/langchain-openai/src/tests/chat_models.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any, no-process-env */ import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; From 7a65cdd5e57a714f79066ac5af75ba668ffa2848 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 13:57:55 -0700 Subject: [PATCH 05/13] cr --- langchain-core/src/utils/function_calling.ts | 11 +- libs/langchain-openai/src/chat_models.ts | 93 +++++++--- .../src/tests/chat_models.test.ts | 161 ++++++++++++++++-- 3 files changed, 229 insertions(+), 36 deletions(-) diff --git a/langchain-core/src/utils/function_calling.ts b/langchain-core/src/utils/function_calling.ts index 6155f3cd8f8a..11bba5c8d62b 100644 --- a/langchain-core/src/utils/function_calling.ts +++ b/langchain-core/src/utils/function_calling.ts @@ -13,12 +13,21 @@ import { Runnable, RunnableToolLike } from "../runnables/base.js"; * @returns {FunctionDefinition} The inputted tool in OpenAI function format. */ export function convertToOpenAIFunction( - tool: StructuredToolInterface | RunnableToolLike + tool: StructuredToolInterface | RunnableToolLike, + fields?: { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. + */ + strict?: boolean; + } ): FunctionDefinition { return { name: tool.name, description: tool.description, parameters: zodToJsonSchema(tool.schema), + // Do not include the `strict` field if it is `undefined`. + ...(fields?.strict !== undefined ? 
{ strict: fields.strict } : {}), }; } diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 3c9d43e4b540..c281cc16794e 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -274,6 +274,26 @@ function convertMessagesToOpenAIParams(messages: BaseMessage[]) { }); } +export interface ChatOpenAIStructuredOutputMethodOptions< + IncludeRaw extends boolean +> extends StructuredOutputMethodOptions { + /** + * strict: If `true` and `method` = "function_calling", model output is + * guaranteed to exactly match the schema. If `true`, the input schema + * will also be validated according to + * https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + * If `false`, input schema will not be validated and model output will not + * be validated. + * If `undefined`, `strict` argument will not be passed to the model. + * + * @version 0.2.6 + * @note Planned breaking change in version `0.3.0`: + * `strict` will default to `true` when `method` is + * "function_calling" as of version `0.3.0`. + */ + strict?: boolean; +} + export interface ChatOpenAICallOptions extends OpenAICallOptions, BaseFunctionCallOptions { @@ -301,8 +321,18 @@ export interface ChatOpenAICallOptions parallel_tool_calls?: boolean; /** * If `true`, model output is guaranteed to exactly match the JSON Schema - * provided in the tool definition. + * provided in the tool definition. If `true`, the input schema will also be + * validated according to + * https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + * + * If `false`, input schema will not be validated and model output will not + * be validated. + * + * If `undefined`, `strict` argument will not be passed to the model. + * * Enabled by default for `"gpt-"` models. + * + * @version 0.2.6 */ strict?: boolean; } @@ -455,9 +485,10 @@ export class ChatOpenAI< protected clientConfig: ClientOptions; /** - * Whether the model supports the 'strict' argument when passing in tools. + * Whether the model supports the `strict` argument when passing in tools. * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise - * defaults to `false`. + * defaults to `undefined`. If `undefined` the `strict` argument will not + * be passed to OpenAI. */ supportsStrictToolCalling?: boolean; @@ -559,10 +590,13 @@ export class ChatOpenAI< }; // Assume only "gpt-..." models support strict tool calling as of 08/06/24. - this.supportsStrictToolCalling = - fields?.supportsStrictToolCalling !== undefined - ? fields.supportsStrictToolCalling - : this.modelName.startsWith("gpt-"); + // If `supportsStrictToolCalling` is explicitly set, use that value, or `true` + // if the model name starts with "gpt-". Else leave undefined so it's not passed to OpenAI. + if (fields?.supportsStrictToolCalling !== undefined) { + this.supportsStrictToolCalling = fields.supportsStrictToolCalling; + } else if (this.modelName.startsWith("gpt-")) { + this.supportsStrictToolCalling = true; + } } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -585,10 +619,12 @@ export class ChatOpenAI< )[], kwargs?: Partial ): Runnable { - const strict = - kwargs?.strict !== undefined - ? 
kwargs.strict - : this.supportsStrictToolCalling; + let strict: boolean | undefined; + if (kwargs?.strict !== undefined) { + strict = kwargs.strict; + } else if (this.supportsStrictToolCalling !== undefined) { + strict = this.supportsStrictToolCalling; + } return this.bind({ tools: tools.map((tool) => convertToOpenAITool(tool, { strict })), ...kwargs, @@ -604,10 +640,13 @@ export class ChatOpenAI< streaming?: boolean; } ): Omit { - const strict = - options?.strict !== undefined - ? options.strict - : this.supportsStrictToolCalling; + let strict: boolean | undefined; + if (options?.strict !== undefined) { + strict = options.strict; + } else if (this.supportsStrictToolCalling !== undefined) { + strict = this.supportsStrictToolCalling; + } + function isStructuredToolArray( tools?: unknown[] ): tools is StructuredToolInterface[] { @@ -646,7 +685,13 @@ export class ChatOpenAI< function_call: options?.function_call, tools: isStructuredToolArray(options?.tools) ? options?.tools.map((tool) => convertToOpenAITool(tool, { strict })) - : options?.tools, + : options?.tools?.map((tool) => { + const toolCopy = { ...tool }; + if (strict !== undefined) { + toolCopy.function.strict = strict; + } + return toolCopy; + }), tool_choice: formatToOpenAIToolChoice(options?.tool_choice), response_format: options?.response_format, seed: options?.seed, @@ -1128,7 +1173,7 @@ export class ChatOpenAI< | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, - config?: StructuredOutputMethodOptions + config?: ChatOpenAIStructuredOutputMethodOptions ): Runnable; withStructuredOutput< @@ -1140,7 +1185,7 @@ export class ChatOpenAI< | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, - config?: StructuredOutputMethodOptions + config?: ChatOpenAIStructuredOutputMethodOptions ): Runnable; withStructuredOutput< @@ -1152,7 +1197,7 @@ export class ChatOpenAI< | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, - config?: StructuredOutputMethodOptions + config?: ChatOpenAIStructuredOutputMethodOptions ): | Runnable | Runnable< @@ -1178,6 +1223,12 @@ export class ChatOpenAI< let llm: Runnable; let outputParser: BaseLLMOutputParser; + if (config?.strict !== undefined && method === "jsonMode") { + throw new Error( + "Argument `strict` is only supported for `method` = 'function_calling'" + ); + } + if (method === "jsonMode") { llm = this.bind({ response_format: { type: "json_object" }, @@ -1209,6 +1260,8 @@ export class ChatOpenAI< name: functionName, }, }, + // Do not pass `strict` argument to OpenAI if `config.strict` is undefined + ...(config?.strict !== undefined ? { strict: config.strict } : {}), } as Partial); outputParser = new JsonOutputKeyToolsParser({ returnSingle: true, @@ -1245,6 +1298,8 @@ export class ChatOpenAI< name: functionName, }, }, + // Do not pass `strict` argument to OpenAI if `config.strict` is undefined + ...(config?.strict !== undefined ? 
{ strict: config.strict } : {}), } as Partial); outputParser = new JsonOutputKeyToolsParser({ returnSingle: true, diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts index 6ae9dd4bc00b..a2ca273c5c65 100644 --- a/libs/langchain-openai/src/tests/chat_models.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -35,16 +35,16 @@ describe("strict tool calling", () => { }); it("Can accept strict as a call arg via .bindTools", async () => { - const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); - mockFetch.mockImplementation((url, options): Promise => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { // Store the request details for later inspection - mockFetch.mock.calls.push({ url, options } as any); + mockFetch.mock.calls.push([url, options]); // Return a mock response return Promise.resolve({ ok: true, json: () => Promise.resolve({}), - }) as Promise; + }); }); const model = new ChatOpenAI({ @@ -82,16 +82,16 @@ describe("strict tool calling", () => { }); it("Can accept strict as a call arg via .bind", async () => { - const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); - mockFetch.mockImplementation((url, options): Promise => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { // Store the request details for later inspection - mockFetch.mock.calls.push({ url, options } as any); + mockFetch.mock.calls.push([url, options]); // Return a mock response return Promise.resolve({ ok: true, json: () => Promise.resolve({}), - }) as Promise; + }); }); const model = new ChatOpenAI({ @@ -132,16 +132,16 @@ describe("strict tool calling", () => { }); it("Sets strict to true if the model name starts with 'gpt-'", async () => { - const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); - mockFetch.mockImplementation((url, options): Promise => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { // Store the request details for later inspection - mockFetch.mock.calls.push({ url, options } as any); + mockFetch.mock.calls.push([url, options]); // Return a mock response return Promise.resolve({ ok: true, json: () => Promise.resolve({}), - }) as Promise; + }); }); const model = new ChatOpenAI({ @@ -180,16 +180,16 @@ describe("strict tool calling", () => { }); it("Strict is false if supportsStrictToolCalling is false", async () => { - const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); - mockFetch.mockImplementation((url, options): Promise => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { // Store the request details for later inspection - mockFetch.mock.calls.push({ url, options } as any); + mockFetch.mock.calls.push([url, options]); // Return a mock response return Promise.resolve({ ok: true, json: () => Promise.resolve({}), - }) as Promise; + }); }); const model = new ChatOpenAI({ @@ -227,4 +227,133 @@ describe("strict tool calling", () => { throw new Error("Body not found in request."); } }); + + // test fails unless it's run in isolation + it.skip("Strict is not passed if non 'gpt-' model is passed.", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + 
mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "doesnt-start-with-gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bindTools([weatherTool]); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + const body = JSON.parse(options.body); + expect(body.tools[0].function).not.toHaveProperty("strict"); + } else { + throw new Error("Body not found in request."); + } + }); + + it("Strict is set to true if passed in .withStructuredOutput", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "doesnt-start-with-gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.withStructuredOutput( + z.object({ + location: z.string().describe("The location to get the weather for"), + }), + { + strict: true, + } + ); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + const body = JSON.parse(options.body); + expect(body.tools[0].function).toHaveProperty("strict", true); + } else { + throw new Error("Body not found in request."); + } + }); + + it("Strict is NOT passed to OpenAI if NOT passed in .withStructuredOutput", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "doesnt-start-with-gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.withStructuredOutput( + z.object({ + location: z.string().describe("The location to get the weather for"), + }) + ); + + // This will fail since we're not returning a valid response in our mocked fetch function. 
+ await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + const body = JSON.parse(options.body); + expect(body.tools[0].function).not.toHaveProperty("strict"); + } else { + throw new Error("Body not found in request."); + } + }); }); From 2405ec5a8fbc1e700082fabff4e9c7e1db987657 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 14:32:03 -0700 Subject: [PATCH 06/13] fix --- langchain-core/src/utils/function_calling.ts | 22 ++++++- libs/langchain-openai/src/chat_models.ts | 68 +++++++++++--------- 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/langchain-core/src/utils/function_calling.ts b/langchain-core/src/utils/function_calling.ts index 11bba5c8d62b..6115da1f50c9 100644 --- a/langchain-core/src/utils/function_calling.ts +++ b/langchain-core/src/utils/function_calling.ts @@ -53,7 +53,7 @@ export function convertToOpenAITool( } ): ToolDefinition { let toolDef: ToolDefinition | undefined; - if (isStructuredTool(tool) || isRunnableToolLike(tool)) { + if (isLangChainTool(tool)) { toolDef = { type: "function", function: convertToOpenAIFunction(tool), @@ -66,7 +66,7 @@ export function convertToOpenAITool( toolDef.function.strict = fields.strict; } - return tool as ToolDefinition; + return toolDef; } /** @@ -100,3 +100,21 @@ export function isRunnableToolLike(tool?: unknown): tool is RunnableToolLike { tool.constructor.lc_name() === "RunnableToolLike" ); } + +/** + * Whether or not the tool is one of StructuredTool, RunnableTool or StructuredToolParams. + * It returns `is StructuredToolParams` since that is the most minimal interface of the three, + * while still containing the necessary properties to be passed to a LLM for tool calling. + * + * @param {unknown | undefined} tool The tool to check if it is a LangChain tool. + * @returns {tool is StructuredToolParams} Whether the inputted tool is a LangChain tool. 
+ */ +export function isLangChainTool( + tool?: unknown +): tool is StructuredToolInterface { + return ( + isRunnableToolLike(tool) || + // eslint-disable-next-line @typescript-eslint/no-explicit-any + isStructuredTool(tool as any) + ); +} diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index c281cc16794e..8729cc30f5ba 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -27,12 +27,13 @@ import { LangSmithParams, type BaseChatModelParams, } from "@langchain/core/language_models/chat_models"; -import type { - BaseFunctionCallOptions, - BaseLanguageModelInput, - FunctionDefinition, - StructuredOutputMethodOptions, - StructuredOutputMethodParams, +import { + isOpenAITool, + type BaseFunctionCallOptions, + type BaseLanguageModelInput, + type FunctionDefinition, + type StructuredOutputMethodOptions, + type StructuredOutputMethodParams, } from "@langchain/core/language_models/base"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; import { convertToOpenAITool } from "@langchain/core/utils/function_calling"; @@ -274,6 +275,25 @@ function convertMessagesToOpenAIParams(messages: BaseMessage[]) { }); } +type ChatOpenAIToolType = + | StructuredToolInterface + | OpenAIClient.ChatCompletionTool + | RunnableToolLike + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record; + +function _convertChatOpenAIToolTypeToOpenAITool( + tool: ChatOpenAIToolType, + fields?: { + strict?: boolean; + } +): OpenAIClient.ChatCompletionTool { + if (isOpenAITool(tool)) { + return tool; + } + return convertToOpenAITool(tool, fields); +} + export interface ChatOpenAIStructuredOutputMethodOptions< IncludeRaw extends boolean > extends StructuredOutputMethodOptions { @@ -297,7 +317,7 @@ export interface ChatOpenAIStructuredOutputMethodOptions< export interface ChatOpenAICallOptions extends OpenAICallOptions, BaseFunctionCallOptions { - tools?: StructuredToolInterface[] | OpenAIClient.ChatCompletionTool[]; + tools?: ChatOpenAIToolType[]; tool_choice?: OpenAIToolChoice; promptIndex?: number; response_format?: { type: "json_object" }; @@ -612,11 +632,7 @@ export class ChatOpenAI< } override bindTools( - tools: ( - | Record - | StructuredToolInterface - | RunnableToolLike - )[], + tools: ChatOpenAIToolType[], kwargs?: Partial ): Runnable { let strict: boolean | undefined; @@ -626,7 +642,9 @@ export class ChatOpenAI< strict = this.supportsStrictToolCalling; } return this.bind({ - tools: tools.map((tool) => convertToOpenAITool(tool, { strict })), + tools: tools.map((tool) => + _convertChatOpenAIToolTypeToOpenAITool(tool, { strict }) + ), ...kwargs, } as Partial); } @@ -647,16 +665,6 @@ export class ChatOpenAI< strict = this.supportsStrictToolCalling; } - function isStructuredToolArray( - tools?: unknown[] - ): tools is StructuredToolInterface[] { - return ( - tools !== undefined && - tools.every((tool) => - Array.isArray((tool as StructuredToolInterface).lc_namespace) - ) - ); - } let streamOptionsConfig = {}; if (options?.stream_options !== undefined) { streamOptionsConfig = { stream_options: options.stream_options }; @@ -683,15 +691,11 @@ export class ChatOpenAI< stream: this.streaming, functions: options?.functions, function_call: options?.function_call, - tools: isStructuredToolArray(options?.tools) - ? 
options?.tools.map((tool) => convertToOpenAITool(tool, { strict })) - : options?.tools?.map((tool) => { - const toolCopy = { ...tool }; - if (strict !== undefined) { - toolCopy.function.strict = strict; - } - return toolCopy; - }), + tools: options?.tools?.length + ? options.tools.map((tool) => + _convertChatOpenAIToolTypeToOpenAITool(tool, { strict }) + ) + : undefined, tool_choice: formatToOpenAIToolChoice(options?.tool_choice), response_format: options?.response_format, seed: options?.seed, From f11f268830469bc4dc717c21adf5b22e0e3b0e8d Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 14:46:33 -0700 Subject: [PATCH 07/13] fixed all tests --- libs/langchain-openai/src/chat_models.ts | 13 +++---------- .../azure/chat_models.standard.int.test.ts | 19 +++++++++++++++++-- libs/langchain-openai/src/types.ts | 5 ++--- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 8729cc30f5ba..c74dcc43615c 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -350,8 +350,6 @@ export interface ChatOpenAICallOptions * * If `undefined`, `strict` argument will not be passed to the model. * - * Enabled by default for `"gpt-"` models. - * * @version 0.2.6 */ strict?: boolean; @@ -506,9 +504,7 @@ export class ChatOpenAI< /** * Whether the model supports the `strict` argument when passing in tools. - * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise - * defaults to `undefined`. If `undefined` the `strict` argument will not - * be passed to OpenAI. + * If `undefined` the `strict` argument will not be passed to OpenAI. */ supportsStrictToolCalling?: boolean; @@ -609,13 +605,10 @@ export class ChatOpenAI< ...fields?.configuration, }; - // Assume only "gpt-..." models support strict tool calling as of 08/06/24. - // If `supportsStrictToolCalling` is explicitly set, use that value, or `true` - // if the model name starts with "gpt-". Else leave undefined so it's not passed to OpenAI. + // If `supportsStrictToolCalling` is explicitly set, use that value. + // Else leave undefined so it's not passed to OpenAI. 
if (fields?.supportsStrictToolCalling !== undefined) { this.supportsStrictToolCalling = fields.supportsStrictToolCalling; - } else if (this.modelName.startsWith("gpt-")) { - this.supportsStrictToolCalling = true; } } diff --git a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts index 8146f04d0f88..64052685d6c2 100644 --- a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts +++ b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts @@ -1,17 +1,25 @@ /* eslint-disable no-process-env */ -import { test, expect } from "@jest/globals"; +import { test, expect, beforeAll, afterAll } from "@jest/globals"; import { ChatModelIntegrationTests } from "@langchain/standard-tests"; import { AIMessageChunk } from "@langchain/core/messages"; import { AzureChatOpenAI } from "../../azure/chat_models.js"; import { ChatOpenAICallOptions } from "../../chat_models.js"; +let openAIAPIKey: string | undefined; + beforeAll(() => { + if (process.env.OPENAI_API_KEY) { + openAIAPIKey = process.env.OPENAI_API_KEY; + process.env.OPENAI_API_KEY = ""; + } + if (!process.env.AZURE_OPENAI_API_KEY) { process.env.AZURE_OPENAI_API_KEY = process.env.TEST_AZURE_OPENAI_API_KEY; } if (!process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME) { process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = - process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME; + process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME ?? + process.env.AZURE_OPENAI_CHAT_DEPLOYMENT_NAME; } if (!process.env.AZURE_OPENAI_BASE_PATH) { process.env.AZURE_OPENAI_BASE_PATH = @@ -23,6 +31,12 @@ beforeAll(() => { } }); +afterAll(() => { + if (openAIAPIKey) { + process.env.OPENAI_API_KEY = openAIAPIKey; + } +}); + class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< ChatOpenAICallOptions, AIMessageChunk @@ -35,6 +49,7 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< supportsParallelToolCalls: true, constructorArgs: { model: "gpt-3.5-turbo", + maxRetries: 0, }, }); } diff --git a/libs/langchain-openai/src/types.ts b/libs/langchain-openai/src/types.ts index afd9fea2624b..0d93089619e2 100644 --- a/libs/langchain-openai/src/types.ts +++ b/libs/langchain-openai/src/types.ts @@ -157,9 +157,8 @@ export interface OpenAIChatInput extends OpenAIBaseInput { __includeRawResponse?: boolean; /** - * Whether the model supports the 'strict' argument when passing in tools. - * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise - * defaults to `false`. + * Whether the model supports the `strict` argument when passing in tools. + * If `undefined` the `strict` argument will not be passed to OpenAI. 
*/ supportsStrictToolCalling?: boolean; } From 643e050a766ee011079821435f4a7e1a2961e429 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 14:55:07 -0700 Subject: [PATCH 08/13] docs --- .../docs/integrations/chat/openai.ipynb | 78 ++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/docs/core_docs/docs/integrations/chat/openai.ipynb b/docs/core_docs/docs/integrations/chat/openai.ipynb index 8dbf079948d8..8dc1cf97f640 100644 --- a/docs/core_docs/docs/integrations/chat/openai.ipynb +++ b/docs/core_docs/docs/integrations/chat/openai.ipynb @@ -411,7 +411,7 @@ }, { "cell_type": "markdown", - "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "id": "bc5ecebd", "metadata": {}, "source": [ "## Tool calling\n", @@ -420,8 +420,82 @@ "\n", "- [How to: disable parallel tool calling](/docs/how_to/tool_calling_parallel/)\n", "- [How to: force a tool call](/docs/how_to/tool_choice/)\n", - "- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced).\n", + "- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced)." + ] + }, + { + "cell_type": "markdown", + "id": "3392390e", + "metadata": {}, + "source": [ + "### ``strict: true``\n", + "\n", + "```{=mdx}\n", + "\n", + ":::info Requires ``@langchain/openai >= 0.2.6``\n", + "\n", + "As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n", + "\n", + "**Note**: If ``strict: true`` the tool definition will also be validated, and a subset of JSON schema are accepted. Crucially, schema cannot have optional args (those with default values). Read the full docs on what types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. 
\n", + ":::\n", + "\n", + "\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "90f0d465", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " name: 'get_current_weather',\n", + " args: { location: 'Hanoi' },\n", + " type: 'tool_call',\n", + " id: 'call_aB85ybkLCoccpzqHquuJGH3d'\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "import { ChatOpenAI } from \"@langchain/openai\";\n", + "import { tool } from \"@langchain/core/tools\";\n", + "import { z } from \"zod\";\n", + "\n", + "const weatherTool = tool((_) => \"no-op\", {\n", + " name: \"get_current_weather\",\n", + " description: \"Get the current weather\",\n", + " schema: z.object({\n", + " location: z.string(),\n", + " }),\n", + "})\n", + "\n", + "const llmWithStrictTrue = new ChatOpenAI({\n", + " model: \"gpt-4o\",\n", + "}).bindTools([weatherTool], {\n", + " strict: true,\n", + " tool_choice: weatherTool.name,\n", + "});\n", + "\n", + "// Although the question is not about the weather, it will call the tool with the correct arguments\n", + "// because we passed `tool_choice` and `strict: true`.\n", + "const strictTrueResult = await llmWithStrictTrue.invoke(\"What is 127862 times 12898 divided by 2?\");\n", "\n", + "console.dir(strictTrueResult.tool_calls, { depth: null });" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "metadata": {}, + "source": [ "## API reference\n", "\n", "For detailed documentation of all ChatOpenAI features and configurations head to the API reference: https://api.js.langchain.com/classes/langchain_openai.ChatOpenAI.html" From 04365e406ee45d81c8ad3ceb839397dc74d1f351 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 14:58:06 -0700 Subject: [PATCH 09/13] fix build errors --- langchain/src/agents/openai_tools/index.ts | 2 +- langchain/src/agents/openai_tools/output_parser.ts | 2 +- libs/langchain-groq/src/chat_models.ts | 2 +- libs/langchain-ollama/src/chat_models.ts | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/langchain/src/agents/openai_tools/index.ts b/langchain/src/agents/openai_tools/index.ts index ae071993224e..8e6fe6f61927 100644 --- a/langchain/src/agents/openai_tools/index.ts +++ b/langchain/src/agents/openai_tools/index.ts @@ -116,7 +116,7 @@ export async function createOpenAIToolsAgent({ ].join("\n") ); } - const modelWithTools = llm.bind({ tools: tools.map(convertToOpenAITool) }); + const modelWithTools = llm.bind({ tools: tools.map((tool) => convertToOpenAITool(tool)) }); const agent = AgentRunnableSequence.fromRunnables( [ RunnablePassthrough.assign({ diff --git a/langchain/src/agents/openai_tools/output_parser.ts b/langchain/src/agents/openai_tools/output_parser.ts index dbaa15d8ad27..c18d6a1ff2ab 100644 --- a/langchain/src/agents/openai_tools/output_parser.ts +++ b/langchain/src/agents/openai_tools/output_parser.ts @@ -30,7 +30,7 @@ export type { ToolsAgentAction, ToolsAgentStep }; * new ChatOpenAI({ * modelName: "gpt-3.5-turbo-1106", * temperature: 0, - * }).bind({ tools: tools.map(convertToOpenAITool) }), + * }).bind({ tools: tools.map((tool) => convertToOpenAITool(tool)) }), * new OpenAIToolsAgentOutputParser(), * ]).withConfig({ runName: "OpenAIToolsAgent" }); * diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index b2291dc552ce..413e8803fdff 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -437,7 +437,7 @@ export class 
ChatGroq extends BaseChatModel< kwargs?: Partial ): Runnable { return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => convertToOpenAITool(tool)), ...kwargs, }); } diff --git a/libs/langchain-ollama/src/chat_models.ts b/libs/langchain-ollama/src/chat_models.ts index 9f70a9e0e0b0..d08471ab0506 100644 --- a/libs/langchain-ollama/src/chat_models.ts +++ b/libs/langchain-ollama/src/chat_models.ts @@ -298,7 +298,7 @@ export class ChatOllama kwargs?: Partial ): Runnable { return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => convertToOpenAITool(tool)), ...kwargs, }); } @@ -359,7 +359,7 @@ export class ChatOllama stop: options?.stop, }, tools: options?.tools?.length - ? (options.tools.map(convertToOpenAITool) as OllamaTool[]) + ? (options.tools.map((tool) => convertToOpenAITool(tool)) as OllamaTool[]) : undefined, }; } From d64151e6646bd4d29e4c99510e08e9e985a80777 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 15:06:10 -0700 Subject: [PATCH 10/13] fix more type errors --- langchain-core/src/utils/function_calling.ts | 44 ++++++++++++-------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/langchain-core/src/utils/function_calling.ts b/langchain-core/src/utils/function_calling.ts index 6115da1f50c9..38a976f75d7b 100644 --- a/langchain-core/src/utils/function_calling.ts +++ b/langchain-core/src/utils/function_calling.ts @@ -14,20 +14,25 @@ import { Runnable, RunnableToolLike } from "../runnables/base.js"; */ export function convertToOpenAIFunction( tool: StructuredToolInterface | RunnableToolLike, - fields?: { - /** - * If `true`, model output is guaranteed to exactly match the JSON Schema - * provided in the function definition. - */ - strict?: boolean; - } + fields?: + | { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. + */ + strict?: boolean; + } + | number ): FunctionDefinition { + // @TODO 0.3.0 Remove the `number` typing + const fieldsCopy = typeof fields === "number" ? undefined : fields; + return { name: tool.name, description: tool.description, parameters: zodToJsonSchema(tool.schema), // Do not include the `strict` field if it is `undefined`. - ...(fields?.strict !== undefined ? { strict: fields.strict } : {}), + ...(fieldsCopy?.strict !== undefined ? { strict: fieldsCopy.strict } : {}), }; } @@ -44,14 +49,19 @@ export function convertToOpenAIFunction( export function convertToOpenAITool( // eslint-disable-next-line @typescript-eslint/no-explicit-any tool: StructuredToolInterface | Record | RunnableToolLike, - fields?: { - /** - * If `true`, model output is guaranteed to exactly match the JSON Schema - * provided in the function definition. - */ - strict?: boolean; - } + fields?: + | { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. + */ + strict?: boolean; + } + | number ): ToolDefinition { + // @TODO 0.3.0 Remove the `number` typing + const fieldsCopy = typeof fields === "number" ? 
undefined : fields; + let toolDef: ToolDefinition | undefined; if (isLangChainTool(tool)) { toolDef = { @@ -62,8 +72,8 @@ export function convertToOpenAITool( toolDef = tool as ToolDefinition; } - if (fields?.strict !== undefined) { - toolDef.function.strict = fields.strict; + if (fieldsCopy?.strict !== undefined) { + toolDef.function.strict = fieldsCopy.strict; } return toolDef; From e11697c6b4cfc1d7d2f9acba1da22baaa803b8c2 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 15:17:25 -0700 Subject: [PATCH 11/13] cr --- langchain/src/agents/openai_tools/index.ts | 4 +++- libs/langchain-ollama/src/chat_models.ts | 4 +++- libs/langchain-openai/src/tests/chat_models.test.ts | 7 +++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/langchain/src/agents/openai_tools/index.ts b/langchain/src/agents/openai_tools/index.ts index 8e6fe6f61927..fe13da61f844 100644 --- a/langchain/src/agents/openai_tools/index.ts +++ b/langchain/src/agents/openai_tools/index.ts @@ -116,7 +116,9 @@ export async function createOpenAIToolsAgent({ ].join("\n") ); } - const modelWithTools = llm.bind({ tools: tools.map((tool) => convertToOpenAITool(tool)) }); + const modelWithTools = llm.bind({ + tools: tools.map((tool) => convertToOpenAITool(tool)), + }); const agent = AgentRunnableSequence.fromRunnables( [ RunnablePassthrough.assign({ diff --git a/libs/langchain-ollama/src/chat_models.ts b/libs/langchain-ollama/src/chat_models.ts index d08471ab0506..15c7ca31897e 100644 --- a/libs/langchain-ollama/src/chat_models.ts +++ b/libs/langchain-ollama/src/chat_models.ts @@ -359,7 +359,9 @@ export class ChatOllama stop: options?.stop, }, tools: options?.tools?.length - ? (options.tools.map((tool) => convertToOpenAITool(tool)) as OllamaTool[]) + ? (options.tools.map((tool) => + convertToOpenAITool(tool) + ) as OllamaTool[]) : undefined, }; } diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts index a2ca273c5c65..2ba730a07242 100644 --- a/libs/langchain-openai/src/tests/chat_models.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -49,6 +49,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, @@ -96,6 +97,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, @@ -146,6 +148,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, @@ -194,6 +197,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, @@ -244,6 +248,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "doesnt-start-with-gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, @@ -283,6 +288,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "doesnt-start-with-gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, @@ -329,6 +335,7 @@ describe("strict tool calling", () => { const model = new ChatOpenAI({ model: "doesnt-start-with-gpt-4", + apiKey: "test-key", configuration: { fetch: mockFetch, }, From 2a2999477256d6c8efb7f95d773e86544f81fd8c Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 15:28:44 -0700 Subject: [PATCH 12/13] cr --- libs/langchain-openai/src/chat_models.ts | 
10 ++ .../src/tests/chat_models.test.ts | 124 +----------------- 2 files changed, 14 insertions(+), 120 deletions(-) diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index c74dcc43615c..80ebccb2c0fc 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -289,6 +289,16 @@ function _convertChatOpenAIToolTypeToOpenAITool( } ): OpenAIClient.ChatCompletionTool { if (isOpenAITool(tool)) { + if (fields?.strict !== undefined) { + return { + ...tool, + function: { + ...tool.function, + strict: fields.strict, + } + } + } + return tool; } return convertToOpenAITool(tool, fields); diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts index 2ba730a07242..ae15247a6156 100644 --- a/libs/langchain-openai/src/tests/chat_models.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -67,16 +67,7 @@ describe("strict tool calling", () => { const [_url, options] = mockFetch.mock.calls[0]; if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([ - expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bindTools` - strict: true, - }, - }), - ]); + expect(JSON.parse(options.body).tools[0].function).toHaveProperty("strict", true); } else { throw new Error("Body not found in request."); } @@ -118,65 +109,7 @@ describe("strict tool calling", () => { const [_url, options] = mockFetch.mock.calls[0]; if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([ - expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bind` - strict: true, - }, - }), - ]); - } else { - throw new Error("Body not found in request."); - } - }); - - it("Sets strict to true if the model name starts with 'gpt-'", async () => { - const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); - mockFetch.mockImplementation((url, options) => { - // Store the request details for later inspection - mockFetch.mock.calls.push([url, options]); - - // Return a mock response - return Promise.resolve({ - ok: true, - json: () => Promise.resolve({}), - }); - }); - - const model = new ChatOpenAI({ - model: "gpt-4", - apiKey: "test-key", - configuration: { - fetch: mockFetch, - }, - maxRetries: 0, - }); - - // Do NOT pass `strict` here since we're checking that it's set to true by default - const modelWithTools = model.bindTools([weatherTool]); - - // This will fail since we're not returning a valid response in our mocked fetch function. 
- await expect( - modelWithTools.invoke("What's the weather like?") - ).rejects.toThrow(); - - expect(mockFetch).toHaveBeenCalled(); - const [_url, options] = mockFetch.mock.calls[0]; - - if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([ - expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bind` - strict: true, - }, - }), - ]); + expect(JSON.parse(options.body).tools[0].function).toHaveProperty("strict", true); } else { throw new Error("Body not found in request."); } @@ -217,57 +150,7 @@ describe("strict tool calling", () => { const [_url, options] = mockFetch.mock.calls[0]; if (options && options.body) { - expect(JSON.parse(options.body).tools).toEqual([ - expect.objectContaining({ - type: "function", - function: { - ...weatherTool.function, - // This should be added to the function call because `strict` was passed to `bind` - strict: false, - }, - }), - ]); - } else { - throw new Error("Body not found in request."); - } - }); - - // test fails unless it's run in isolation - it.skip("Strict is not passed if non 'gpt-' model is passed.", async () => { - const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); - mockFetch.mockImplementation((url, options) => { - // Store the request details for later inspection - mockFetch.mock.calls.push([url, options]); - - // Return a mock response - return Promise.resolve({ - ok: true, - json: () => Promise.resolve({}), - }); - }); - - const model = new ChatOpenAI({ - model: "doesnt-start-with-gpt-4", - apiKey: "test-key", - configuration: { - fetch: mockFetch, - }, - maxRetries: 0, - }); - - const modelWithTools = model.bindTools([weatherTool]); - - // This will fail since we're not returning a valid response in our mocked fetch function. 
- await expect( - modelWithTools.invoke("What's the weather like?") - ).rejects.toThrow(); - - expect(mockFetch).toHaveBeenCalled(); - const [_url, options] = mockFetch.mock.calls[0]; - - if (options && options.body) { - const body = JSON.parse(options.body); - expect(body.tools[0].function).not.toHaveProperty("strict"); + expect(JSON.parse(options.body).tools[0].function).toHaveProperty("strict", false); } else { throw new Error("Body not found in request."); } @@ -293,6 +176,7 @@ describe("strict tool calling", () => { fetch: mockFetch, }, maxRetries: 0, + supportsStrictToolCalling: true, }); const modelWithTools = model.withStructuredOutput( From 7bd1f8143065d8dfbce4721a270b209ce06bfa98 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 15:28:53 -0700 Subject: [PATCH 13/13] chore: lint files --- libs/langchain-openai/src/chat_models.ts | 4 ++-- .../src/tests/chat_models.test.ts | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 80ebccb2c0fc..52d2eeb52bf4 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -295,8 +295,8 @@ function _convertChatOpenAIToolTypeToOpenAITool( function: { ...tool.function, strict: fields.strict, - } - } + }, + }; } return tool; diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts index ae15247a6156..a24c180ff1d0 100644 --- a/libs/langchain-openai/src/tests/chat_models.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -67,7 +67,10 @@ describe("strict tool calling", () => { const [_url, options] = mockFetch.mock.calls[0]; if (options && options.body) { - expect(JSON.parse(options.body).tools[0].function).toHaveProperty("strict", true); + expect(JSON.parse(options.body).tools[0].function).toHaveProperty( + "strict", + true + ); } else { throw new Error("Body not found in request."); } @@ -109,7 +112,10 @@ describe("strict tool calling", () => { const [_url, options] = mockFetch.mock.calls[0]; if (options && options.body) { - expect(JSON.parse(options.body).tools[0].function).toHaveProperty("strict", true); + expect(JSON.parse(options.body).tools[0].function).toHaveProperty( + "strict", + true + ); } else { throw new Error("Body not found in request."); } @@ -150,7 +156,10 @@ describe("strict tool calling", () => { const [_url, options] = mockFetch.mock.calls[0]; if (options && options.body) { - expect(JSON.parse(options.body).tools[0].function).toHaveProperty("strict", false); + expect(JSON.parse(options.body).tools[0].function).toHaveProperty( + "strict", + false + ); } else { throw new Error("Body not found in request."); }
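
Taken together, patches 01 through 13 make `strict` flow from `bindTools`/`withStructuredOutput` through `convertToOpenAITool` into each tool definition's `function.strict` field, with the default driven by `supportsStrictToolCalling`. Below is a minimal consumer-side sketch of that plumbing (not part of the patch series itself), assuming the `@langchain/core` and `@langchain/openai` versions this series targets, an `OPENAI_API_KEY` in the environment, and a hypothetical model name `my-proxy-model` used purely for illustration:

```typescript
import { ChatOpenAI } from "@langchain/openai";
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { tool } from "@langchain/core/tools";
import { z } from "zod";

// Illustrative tool; any Zod-backed LangChain tool is handled the same way.
const weatherTool = tool((_) => "no-op", {
  name: "get_current_weather",
  description: "Get the current weather",
  schema: z.object({ location: z.string() }),
});

// Core level: the new optional second argument stamps `strict` onto the
// generated OpenAI tool definition.
const weatherToolDef = convertToOpenAITool(weatherTool, { strict: true });
console.log(weatherToolDef.function.strict); // true

// OpenAI level: for "gpt-" model names `supportsStrictToolCalling` defaults
// to true, so bound tools are sent with `"strict": true` unless the caller
// passes `strict: false`.
const gptModel = new ChatOpenAI({
  model: "gpt-4o",
  apiKey: process.env.OPENAI_API_KEY,
});
const gptWithTools = gptModel.bindTools([weatherTool]);

// For other model names (hypothetical proxy shown here) strict stays off by
// default; it can be enabled per bind via `strict: true`, or for the whole
// instance via the constructor flag this series adds.
const proxyModel = new ChatOpenAI({
  model: "my-proxy-model",
  apiKey: process.env.OPENAI_API_KEY,
  supportsStrictToolCalling: true,
});
const proxyWithTools = proxyModel.bindTools([weatherTool], { strict: true });

// Both runnables now include `function.strict` on each tool in the request payload.
void gptWithTools;
void proxyWithTools;
```

Keeping `strict` off for unrecognized model names follows the constructor comment in patch 01 ("Assume only 'gpt-...' models support strict tool calling as of 08/06/24"): backends that do not recognize the field are left untouched by default, and callers opt in explicitly where they know it is supported.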