Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

standard-tests[minor]: Add tests for parallel tool calls #6258

Merged
merged 9 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,24 @@ class ChatAnthropicStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatAnthropic,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
model: "claude-3-haiku-20240307",
},
});
}

async testParallelToolCalling() {
// Override constructor args to use a better model for this test.
// I found that haiku struggles with parallel tool calling.
const constructorArgsCopy = { ...this.constructorArgs };
this.constructorArgs = {
...this.constructorArgs,
model: "claude-3-5-sonnet-20240620",
};
await super.testParallelToolCalling();
this.constructorArgs = constructorArgsCopy;
}
}

const testClass = new ChatAnthropicStandardIntegrationTests();
Expand Down
7 changes: 7 additions & 0 deletions libs/langchain-aws/src/tests/chat_models.standard.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTe
Cls: ChatBedrockConverse,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
region,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
Expand Down Expand Up @@ -51,6 +52,12 @@ class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTe
"Not properly implemented."
);
}

async testParallelToolCalling() {
// Pass `true` in the second argument to only verify it can support parallel tool calls in the message history.
// This is because the model struggles to actually call parallel tools.
await super.testParallelToolCalling(undefined, true);
}
}

const testClass = new ChatBedrockConverseStandardIntegrationTests();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegratio
Cls: ChatGoogleGenerativeAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
maxRetries: 1,
model: "gemini-1.5-pro",
},
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatVertexAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
invokeResponseType: AIMessageChunk,
constructorArgs: {
model: "gemini-1.5-pro",
Expand All @@ -42,6 +43,12 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
"Google VertexAI only supports objects in schemas when the parameters are defined."
);
}

async testParallelToolCalling() {
// Pass `true` in the second argument to only verify it can support parallel tool calls in the message history.
// This is because the model struggles to actually call parallel tools.
await super.testParallelToolCalling(undefined, true);
}
}

const testClass = new ChatVertexAIStandardIntegrationTests();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: AzureChatOpenAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
model: "gpt-3.5-turbo",
},
Expand Down Expand Up @@ -62,6 +63,12 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
"AzureChatOpenAI only supports objects in schemas when the parameters are defined."
);
}

async testParallelToolCalling() {
// Pass `true` in the second argument to only verify it can support parallel tool calls in the message history.
// This is because the model struggles to actually call parallel tools.
await super.testParallelToolCalling(undefined, true);
}
}

const testClass = new AzureChatOpenAIStandardIntegrationTests();
Expand Down
13 changes: 13 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.standard.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatOpenAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
model: "gpt-3.5-turbo",
},
Expand All @@ -44,6 +45,18 @@ class ChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
"\nOpenAI only supports objects in schemas when the parameters are defined."
);
}

async testParallelToolCalling() {
// Override constructor args to use a better model for this test.
// I found that GPT 3.5 struggles with parallel tool calling.
const constructorArgsCopy = { ...this.constructorArgs };
this.constructorArgs = {
...this.constructorArgs,
model: "gpt-4o",
};
await super.testParallelToolCalling();
this.constructorArgs = constructorArgsCopy;
}
}

const testClass = new ChatOpenAIStandardIntegrationTests();
Expand Down
214 changes: 214 additions & 0 deletions libs/langchain-standard-tests/src/integration_tests/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@
* @default "abc123"
*/
functionId?: string;
/**
* Whether or not the model supports parallel tool calling.
* @default false
*/
supportsParallelToolCalls?: boolean;
}

export abstract class ChatModelIntegrationTests<
Expand All @@ -82,6 +87,8 @@

invokeResponseType: typeof AIMessage | typeof AIMessageChunk = AIMessage;

supportsParallelToolCalls = false;

constructor(
fields: ChatModelIntegrationTestsFields<
CallOptions,
Expand All @@ -93,6 +100,8 @@
this.functionId = fields.functionId ?? this.functionId;
this.invokeResponseType =
fields.invokeResponseType ?? this.invokeResponseType;
this.supportsParallelToolCalls =
fields.supportsParallelToolCalls ?? this.supportsParallelToolCalls;
}

/**
Expand Down Expand Up @@ -939,7 +948,7 @@
],
});
const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({} as any);

Check warning on line 951 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type

// Invoke the model to trigger a cache update
await model.invoke([humanMessage], callOptions);
Expand Down Expand Up @@ -1313,6 +1322,204 @@
expect(typeof result.apiDetails === "object").toBeTruthy();
}

/**
* Tests the chat model's ability to handle parallel tool calls in various scenarios.
* This comprehensive test covers three aspects of parallel tool calling:
* 1. Invoking multiple tools simultaneously
* 2. Streaming responses with parallel tool calls
* 3. Processing message histories containing parallel tool calls
*
* The test uses a weather tool and a current time tool to simulate complex, multi-tool scenarios.
* It ensures that the model can correctly process and respond to prompts requiring multiple tool calls,
* both in streaming and non-streaming contexts, and can handle message histories with parallel tool calls.
*
* @param {InstanceType<this["Cls"]>["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model.
* @param {boolean} onlyVerifyHistory If true, only verifies the message history test.
*/
async testParallelToolCalling(
callOptions?: InstanceType<this["Cls"]>["ParsedCallOptions"],
onlyVerifyHistory = false
) {
// Skip the test if the model doesn't support tool calling
if (!this.chatModelHasToolCalling) {
console.log("Test requires tool calling. Skipping...");
return;
}
// Skip the test if the model doesn't support parallel tool calls
if (!this.supportsParallelToolCalls) {
console.log("Test requires parallel tool calls. Skipping...");
return;
}
const model = new this.Cls(this.constructorArgs);
if (!model.bindTools) {
throw new Error(
"bindTools undefined. Cannot test OpenAI formatted tool calls."
);
}

const weatherTool = tool((_) => "no-op", {
name: "get_current_weather",
description: "Get the current weather in a given location",
schema: z.object({
location: z.string().describe("The city name, e.g. San Francisco"),
}),
});
const currentTimeTool = tool((_) => "no-op", {
name: "get_current_time",
description: "Get the current time in a given location",
schema: z.object({
location: z.string().describe("The city name, e.g. San Francisco"),
}),
});

const modelWithTools = model.bindTools([weatherTool, currentTimeTool]);

const callParallelToolsPrompt =
"What's the weather and current time in San Francisco?\n" +
"Ensure you ALWAYS call the 'get_current_weather' tool for weather and 'get_current_time' tool for time.";

// Save the result of the parallel tool calls for the history test.
let parallelToolCallsMessage: AIMessage | undefined;

/**
* Tests the basic functionality of invoking multiple tools in parallel.
* Verifies that the model can call both the weather and current time tools simultaneously.
*/
const invokeParallelTools = async () => {
const result: AIMessage = await modelWithTools.invoke(
callParallelToolsPrompt,
callOptions
);
// Model should call at least two tools. Using greater than or equal since it might call the current time tool multiple times.
expect(result.tool_calls?.length).toBeGreaterThanOrEqual(2);
if (!result.tool_calls?.length) return;

const weatherToolCalls = result.tool_calls.find(
(tc) => tc.name === weatherTool.name
);
const currentTimeToolCalls = result.tool_calls.find(
(tc) => tc.name === currentTimeTool.name
);

expect(weatherToolCalls).toBeDefined();
expect(currentTimeToolCalls).toBeDefined();
parallelToolCallsMessage = result;
};

/**
* Tests the model's ability to stream responses while making parallel tool calls.
* Ensures that the streamed result contains calls to both the weather and current time tools.
*/
const streamParallelTools = async () => {
const stream = await modelWithTools.stream(
callParallelToolsPrompt,
callOptions
);
let finalChunk: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}

expect(finalChunk).toBeDefined();
if (!finalChunk) return;

// Model should call at least two tools. Do not penalize for calling more than two tools, as
// long as it calls both the weather and current time tools.
expect(finalChunk.tool_calls?.length).toBeGreaterThanOrEqual(2);
if (!finalChunk.tool_calls?.length) return;

const weatherToolCalls = finalChunk.tool_calls.find(
(tc) => tc.name === weatherTool.name
);
const currentTimeToolCalls = finalChunk.tool_calls.find(
(tc) => tc.name === currentTimeTool.name
);

expect(weatherToolCalls).toBeDefined();
expect(currentTimeToolCalls).toBeDefined();
};

/**
* Tests the model's ability to process a message history containing parallel tool calls.
* Verifies that the model can generate a response based on previous tool calls without making unnecessary additional tool calls.
*/
const invokeParallelToolCallResultsInHistory = async () => {
const defaultAIMessageWithParallelTools = new AIMessage({
content: "",
tool_calls: [
{
name: weatherTool.name,
id: "get_current_weather_id",
args: { location: "San Francisco" },
},
{
name: currentTimeTool.name,
id: "get_current_time_id",
args: { location: "San Francisco" },
},
],
});
if (!parallelToolCallsMessage) {
// Allow this variable to be assigned in the first test, or if only run histories
// is passed, assign it here since the first test will not run.
parallelToolCallsMessage = defaultAIMessageWithParallelTools;
}
// Find the tool calls for the weather and current time tools so we can re-use the IDs in the message history.
const parallelToolCallWeather = parallelToolCallsMessage.tool_calls?.find(
(tc) => tc.name === weatherTool.name
);
const parallelToolCallCurrentTime =
parallelToolCallsMessage.tool_calls?.find(
(tc) => tc.name === currentTimeTool.name
);
if (!parallelToolCallWeather?.id || !parallelToolCallCurrentTime?.id) {
throw new Error(
`IDs not found in one of both of parallel tool calls:\nWeather ID: ${parallelToolCallWeather?.id}\nCurrent Time ID: ${parallelToolCallCurrentTime?.id}`
);
}

const messageHistory = [
new HumanMessage(callParallelToolsPrompt),
// The saved message from earlier when we called the model to generate the parallel tool calls.
parallelToolCallsMessage,
new ToolMessage({
name: weatherTool.name,
tool_call_id: parallelToolCallWeather.id,
content: "It is currently 24 degrees with hail in San Francisco.",
}),
new ToolMessage({
name: currentTimeTool.name,
tool_call_id: parallelToolCallCurrentTime.id,
content: "The current time in San Francisco is 12:02 PM.",
}),
];

const result: AIMessage = await modelWithTools.invoke(
messageHistory,
callOptions
);
// The model should NOT call a tool given this message history.
expect(result.tool_calls ?? []).toHaveLength(0);

if (typeof result.content === "string") {
expect(result.content).not.toBe("");
} else {
expect(result.content.length).toBeGreaterThan(0);
const textOrTextDeltaContent = result.content.find(
(c) => c.type === "text" || c.type === "text_delta"
);
expect(textOrTextDeltaContent).toBeDefined();
}
};

// Now we can invoke each of our tests synchronously, as the last test requires the result of the first test.
if (!onlyVerifyHistory) {
await invokeParallelTools();
await streamParallelTools();
}
await invokeParallelToolCallResultsInHistory();
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
Expand All @@ -1324,42 +1531,42 @@

try {
await this.testInvoke();
} catch (e: any) {

Check warning on line 1534 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testInvoke failed", e.message);
}

try {
await this.testStream();
} catch (e: any) {

Check warning on line 1541 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testStream failed", e.message);
}

try {
await this.testBatch();
} catch (e: any) {

Check warning on line 1548 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testBatch failed", e.message);
}

try {
await this.testConversation();
} catch (e: any) {

Check warning on line 1555 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testConversation failed", e.message);
}

try {
await this.testUsageMetadata();
} catch (e: any) {

Check warning on line 1562 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testUsageMetadata failed", e.message);
}

try {
await this.testUsageMetadataStreaming();
} catch (e: any) {

Check warning on line 1569 in libs/langchain-standard-tests/src/integration_tests/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected any. Specify a different type
allTestsPassed = false;
console.error("testUsageMetadataStreaming failed", e.message);
}
Expand Down Expand Up @@ -1451,6 +1658,13 @@
console.error("testInvokeMoreComplexTools failed", e.message);
}

try {
await this.testParallelToolCalling();
} catch (e: any) {
allTestsPassed = false;
console.error("testParallelToolCalling failed", e.message);
}

return allTestsPassed;
}
}
Loading