From b7557524c363cd8fc55f5126fc25fea72406dfdb Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 12 Jun 2024 14:12:27 -0700 Subject: [PATCH] add standard test for openai formatted tools --- .../tests/chattogetherai.standard.test.ts | 4 +- libs/langchain-standard-tests/package.json | 3 +- .../src/integration_tests/chat_models.ts | 39 +++++++++++++++++++ yarn.lock | 10 +++++ 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/libs/langchain-community/src/chat_models/tests/chattogetherai.standard.test.ts b/libs/langchain-community/src/chat_models/tests/chattogetherai.standard.test.ts index 2fc8d1bfb71e..c499b866c561 100644 --- a/libs/langchain-community/src/chat_models/tests/chattogetherai.standard.test.ts +++ b/libs/langchain-community/src/chat_models/tests/chattogetherai.standard.test.ts @@ -11,8 +11,8 @@ class ChatTogetherAIStandardUnitTests extends ChatModelUnitTests< constructor() { super({ Cls: ChatTogetherAI, - chatModelHasToolCalling: false, - chatModelHasStructuredOutput: false, + chatModelHasToolCalling: true, + chatModelHasStructuredOutput: true, constructorArgs: {}, }); process.env.TOGETHER_AI_API_KEY = "test"; diff --git a/libs/langchain-standard-tests/package.json b/libs/langchain-standard-tests/package.json index 794199bf70be..b40701745743 100644 --- a/libs/langchain-standard-tests/package.json +++ b/libs/langchain-standard-tests/package.json @@ -34,7 +34,8 @@ "dependencies": { "@jest/globals": "^29.5.0", "@langchain/core": "workspace:*", - "zod": "^3.22.4" + "zod": "^3.22.4", + "zod-to-json-schema": "^3.23.0" }, "devDependencies": { "@langchain/scripts": "workspace:*", diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts index d8416b743aea..bf30ca1d4519 100644 --- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts +++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts @@ -15,6 +15,7 @@ import { BaseChatModelsTestsFields, RecordStringAny, } from "../base.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; const adderSchema = /* #__PURE__ */ z .object({ @@ -385,6 +386,37 @@ export abstract class ChatModelIntegrationTests< expect([1, 2].includes(resultStringContent.parsed.b)).toBeTruthy(); } + async testBindToolsWithOpenAIFormattedTools() { + if (!this.chatModelHasToolCalling) { + console.log("Test requires tool calling. Skipping..."); + return; + } + + const model = new this.Cls(this.constructorArgs); + if (!model.bindTools) { + throw new Error( + "bindTools undefined. Cannot test OpenAI formatted tool calls." + ); + } + const modelWithTools = model.bindTools([{ + type: "function", + function: { + name: "math_addition", + description: adderSchema.description, + parameters: zodToJsonSchema(adderSchema) + } + }]); + + const result: AIMessage = await modelWithTools.invoke("What is 1 + 2"); + expect(result).toBeInstanceOf(AIMessage); + expect(result.tool_calls).toHaveLength(1); + if (!result.tool_calls) { + throw new Error("result.tool_calls is undefined"); + } + const { tool_calls } = result; + expect(tool_calls[0].name).toBe("math_addition"); + } + /** * Run all unit tests for the chat model. * Each test is wrapped in a try/catch block to prevent the entire test suite from failing. @@ -471,6 +503,13 @@ export abstract class ChatModelIntegrationTests< console.error("testWithStructuredOutputIncludeRaw failed", e); } + try { + await this.testBindToolsWithOpenAIFormattedTools(); + } catch (e: any) { + allTestsPassed = false; + console.error("testBindToolsWithOpenAIFormattedTools failed", e); + } + return allTestsPassed; } } diff --git a/yarn.lock b/yarn.lock index f1fe833cb646..7b2e4d5df0ee 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10258,6 +10258,7 @@ __metadata: ts-jest: ^29.1.0 typescript: ^5.4.5 zod: ^3.22.4 + zod-to-json-schema: ^3.23.0 languageName: unknown linkType: soft @@ -38392,6 +38393,15 @@ __metadata: languageName: node linkType: hard +"zod-to-json-schema@npm:^3.23.0": + version: 3.23.0 + resolution: "zod-to-json-schema@npm:3.23.0" + peerDependencies: + zod: ^3.23.3 + checksum: 56f220f06687b41602478cf19f9fbf04488a450c0e47e6cd6c1dc3b6729e2b1c75f742a52a16cbb11bcdf1ff7b2bf2043dfff59f3784d6ac8ecfa562ce035e21 + languageName: node + linkType: hard + "zod@npm:^3.22.3, zod@npm:^3.22.4": version: 3.22.4 resolution: "zod@npm:3.22.4"