From 87d92d9740dd7a116d9721eaeda71efb6bbca577 Mon Sep 17 00:00:00 2001 From: Quentin GEORGET Date: Tue, 23 Jul 2024 21:17:41 +0200 Subject: [PATCH] aws[patch]: Fix fails when calling multiple tools simultaneously (#6175) * aws[patch]: Fix fails when calling multiple tools simultaneously * adding test cases --------- Co-authored-by: Brace Sproul --- libs/langchain-aws/src/common.ts | 24 +- .../src/tests/chat_models.test.ts | 346 +++++++++++++++--- 2 files changed, 310 insertions(+), 60 deletions(-) diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 94630eda4a26..00dd87f2889c 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -209,7 +209,29 @@ export function convertToConverseMessages(messages: BaseMessage[]): { } }); - return { converseMessages, converseSystem }; + // Combine consecutive user tool result messages into a single message + const combinedConverseMessages = converseMessages.reduce( + (acc, curr) => { + const lastMessage = acc[acc.length - 1]; + + if ( + lastMessage && + lastMessage.role === "user" && + lastMessage.content?.some((c) => "toolResult" in c) && + curr.role === "user" && + curr.content?.some((c) => "toolResult" in c) + ) { + lastMessage.content = lastMessage.content.concat(curr.content); + } else { + acc.push(curr); + } + + return acc; + }, + [] + ); + + return { converseMessages: combinedConverseMessages, converseSystem }; } export function isBedrockTool(tool: unknown): tool is BedrockTool { diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index e944468f8e1c..22ccd342ebc0 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -4,77 +4,305 @@ import { AIMessage, ToolMessage, AIMessageChunk, + BaseMessage, } from "@langchain/core/messages"; import { concat } from "@langchain/core/utils/stream"; +import type { + Message as BedrockMessage, + SystemContentBlock as BedrockSystemContentBlock, +} from "@aws-sdk/client-bedrock-runtime"; import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; -test("convertToConverseMessages works", () => { - const messages = [ - new SystemMessage("You're an advanced AI assistant."), - new HumanMessage( - "What's the weather like today in Berkeley, CA? Use weather.com to check." - ), - new AIMessage({ - content: "", - tool_calls: [ - { - name: "retrieverTool", - args: { - url: "https://weather.com", +describe("convertToConverseMessages", () => { + const testCases: { + name: string; + input: BaseMessage[]; + output: { + converseMessages: BedrockMessage[]; + converseSystem: BedrockSystemContentBlock[]; + }; + }[] = [ + { + name: "empty input", + input: [], + output: { + converseMessages: [], + converseSystem: [], + }, + }, + { + name: "simple messages", + input: [ + new SystemMessage("You're an advanced AI assistant."), + new HumanMessage( + "What's the weather like today in Berkeley, CA? Use weather.com to check." + ), + new AIMessage({ + content: "", + tool_calls: [ + { + name: "retrieverTool", + args: { + url: "https://weather.com", + }, + id: "123_retriever_tool", + }, + ], + }), + new ToolMessage({ + tool_call_id: "123_retriever_tool", + content: "The weather in Berkeley, CA is 70 degrees and sunny.", + }), + ], + output: { + converseMessages: [ + { + role: "user", + content: [ + { + text: "What's the weather like today in Berkeley, CA? Use weather.com to check.", + }, + ], + }, + { + role: "assistant", + content: [ + { + toolUse: { + name: "retrieverTool", + toolUseId: "123_retriever_tool", + input: { + url: "https://weather.com", + }, + }, + }, + ], + }, + { + role: "user", + content: [ + { + toolResult: { + toolUseId: "123_retriever_tool", + content: [ + { + text: "The weather in Berkeley, CA is 70 degrees and sunny.", + }, + ], + }, + }, + ], }, - id: "123_retriever_tool", - }, + ], + converseSystem: [ + { + text: "You're an advanced AI assistant.", + }, + ], + }, + }, + { + name: "consecutive user tool messages", + input: [ + new SystemMessage("You're an advanced AI assistant."), + new HumanMessage( + "What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check." + ), + new AIMessage({ + content: "", + tool_calls: [ + { + name: "retrieverTool", + args: { + url: "https://weather.com", + }, + id: "123_retriever_tool", + }, + { + name: "retrieverTool", + args: { + url: "https://weather.com", + }, + id: "456_retriever_tool", + }, + ], + }), + new ToolMessage({ + tool_call_id: "123_retriever_tool", + content: "The weather in Berkeley, CA is 70 degrees and sunny.", + }), + new ToolMessage({ + tool_call_id: "456_retriever_tool", + content: "The weather in Paris, France is perfect.", + }), + new HumanMessage( + "What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check." + ), + new AIMessage({ + content: "", + tool_calls: [ + { + name: "retrieverTool", + args: { + url: "https://meteofrance.com", + }, + id: "321_retriever_tool", + }, + { + name: "retrieverTool", + args: { + url: "https://meteofrance.com", + }, + id: "654_retriever_tool", + }, + ], + }), + new ToolMessage({ + tool_call_id: "321_retriever_tool", + content: "Why don't you check yourself?", + }), + new ToolMessage({ + tool_call_id: "654_retriever_tool", + content: "The weather in Paris, France is horrible.", + }), ], - }), - new ToolMessage({ - tool_call_id: "123_retriever_tool", - content: "The weather in Berkeley, CA is 70 degrees and sunny.", - }), + output: { + converseSystem: [ + { + text: "You're an advanced AI assistant.", + }, + ], + converseMessages: [ + { + role: "user", + content: [ + { + text: "What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check.", + }, + ], + }, + { + role: "assistant", + content: [ + { + toolUse: { + name: "retrieverTool", + toolUseId: "123_retriever_tool", + input: { + url: "https://weather.com", + }, + }, + }, + { + toolUse: { + name: "retrieverTool", + toolUseId: "456_retriever_tool", + input: { + url: "https://weather.com", + }, + }, + }, + ], + }, + { + role: "user", + content: [ + { + toolResult: { + toolUseId: "123_retriever_tool", + content: [ + { + text: "The weather in Berkeley, CA is 70 degrees and sunny.", + }, + ], + }, + }, + { + toolResult: { + toolUseId: "456_retriever_tool", + content: [ + { + text: "The weather in Paris, France is perfect.", + }, + ], + }, + }, + ], + }, + { + role: "user", + content: [ + { + text: "What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check.", + }, + ], + }, + { + role: "assistant", + content: [ + { + toolUse: { + name: "retrieverTool", + toolUseId: "321_retriever_tool", + input: { + url: "https://meteofrance.com", + }, + }, + }, + { + toolUse: { + name: "retrieverTool", + toolUseId: "654_retriever_tool", + input: { + url: "https://meteofrance.com", + }, + }, + }, + ], + }, + { + role: "user", + content: [ + { + toolResult: { + toolUseId: "321_retriever_tool", + content: [ + { + text: "Why don't you check yourself?", + }, + ], + }, + }, + { + toolResult: { + toolUseId: "654_retriever_tool", + content: [ + { + text: "The weather in Paris, France is horrible.", + }, + ], + }, + }, + ], + }, + ], + }, + }, ]; - const { converseMessages, converseSystem } = - convertToConverseMessages(messages); - - expect(converseSystem).toHaveLength(1); - expect(converseSystem[0].text).toBe("You're an advanced AI assistant."); - - expect(converseMessages).toHaveLength(3); - - const userMsgs = converseMessages.filter((msg) => msg.role === "user"); - // Length of two because of the first user question, and tool use - // messages will have the user role. - expect(userMsgs).toHaveLength(2); - const textUserMsg = userMsgs.find((msg) => msg.content?.[0].text); - expect(textUserMsg?.content?.[0].text).toBe( - "What's the weather like today in Berkeley, CA? Use weather.com to check." + it.each(testCases.map((tc) => [tc.name, tc]))( + "convertToConverseMessages: case %s", + (_, tc) => { + const { converseMessages, converseSystem } = convertToConverseMessages( + tc.input + ); + expect(converseMessages).toEqual(tc.output.converseMessages); + expect(converseSystem).toEqual(tc.output.converseSystem); + } ); - - const toolUseUserMsg = userMsgs.find((msg) => msg.content?.[0].toolResult); - expect(toolUseUserMsg).toBeDefined(); - expect(toolUseUserMsg?.content).toHaveLength(1); - if (!toolUseUserMsg?.content?.length) return; - - const toolResultContent = toolUseUserMsg.content[0]; - expect(toolResultContent).toBeDefined(); - expect(toolResultContent.toolResult?.toolUseId).toBe("123_retriever_tool"); - expect(toolResultContent.toolResult?.content?.[0].text).toBe( - "The weather in Berkeley, CA is 70 degrees and sunny." - ); - - const assistantMsg = converseMessages.find((msg) => msg.role === "assistant"); - expect(assistantMsg).toBeDefined(); - if (!assistantMsg) return; - - const toolUseContent = assistantMsg.content?.find((c) => "toolUse" in c); - expect(toolUseContent).toBeDefined(); - expect(toolUseContent?.toolUse?.name).toBe("retrieverTool"); - expect(toolUseContent?.toolUse?.toolUseId).toBe("123_retriever_tool"); - expect(toolUseContent?.toolUse?.input).toEqual({ - url: "https://weather.com", - }); }); test("Streaming supports empty string chunks", async () => {