Skip to content

Commit

Permalink
aws[patch]: Fix fails when calling multiple tools simultaneously (#6175)
Browse files Browse the repository at this point in the history
* aws[patch]:  Fix fails when calling multiple tools simultaneously

* adding test cases

---------

Co-authored-by: Brace Sproul <[email protected]>
  • Loading branch information
tinque and bracesproul authored Jul 23, 2024
1 parent 99a8760 commit 87d92d9
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 60 deletions.
24 changes: 23 additions & 1 deletion libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,29 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
}
});

return { converseMessages, converseSystem };
// Combine consecutive user tool result messages into a single message
const combinedConverseMessages = converseMessages.reduce<BedrockMessage[]>(
(acc, curr) => {
const lastMessage = acc[acc.length - 1];

if (
lastMessage &&
lastMessage.role === "user" &&
lastMessage.content?.some((c) => "toolResult" in c) &&
curr.role === "user" &&
curr.content?.some((c) => "toolResult" in c)
) {
lastMessage.content = lastMessage.content.concat(curr.content);
} else {
acc.push(curr);
}

return acc;
},
[]
);

return { converseMessages: combinedConverseMessages, converseSystem };
}

export function isBedrockTool(tool: unknown): tool is BedrockTool {
Expand Down
346 changes: 287 additions & 59 deletions libs/langchain-aws/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,77 +4,305 @@ import {
AIMessage,
ToolMessage,
AIMessageChunk,
BaseMessage,
} from "@langchain/core/messages";
import { concat } from "@langchain/core/utils/stream";
import type {
Message as BedrockMessage,
SystemContentBlock as BedrockSystemContentBlock,
} from "@aws-sdk/client-bedrock-runtime";
import {
convertToConverseMessages,
handleConverseStreamContentBlockDelta,
} from "../common.js";

test("convertToConverseMessages works", () => {
const messages = [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
describe("convertToConverseMessages", () => {
const testCases: {
name: string;
input: BaseMessage[];
output: {
converseMessages: BedrockMessage[];
converseSystem: BedrockSystemContentBlock[];
};
}[] = [
{
name: "empty input",
input: [],
output: {
converseMessages: [],
converseSystem: [],
},
},
{
name: "simple messages",
input: [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
],
output: {
converseMessages: [
{
role: "user",
content: [
{
text: "What's the weather like today in Berkeley, CA? Use weather.com to check.",
},
],
},
{
role: "assistant",
content: [
{
toolUse: {
name: "retrieverTool",
toolUseId: "123_retriever_tool",
input: {
url: "https://weather.com",
},
},
},
],
},
{
role: "user",
content: [
{
toolResult: {
toolUseId: "123_retriever_tool",
content: [
{
text: "The weather in Berkeley, CA is 70 degrees and sunny.",
},
],
},
},
],
},
id: "123_retriever_tool",
},
],
converseSystem: [
{
text: "You're an advanced AI assistant.",
},
],
},
},
{
name: "consecutive user tool messages",
input: [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "456_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
new ToolMessage({
tool_call_id: "456_retriever_tool",
content: "The weather in Paris, France is perfect.",
}),
new HumanMessage(
"What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://meteofrance.com",
},
id: "321_retriever_tool",
},
{
name: "retrieverTool",
args: {
url: "https://meteofrance.com",
},
id: "654_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "321_retriever_tool",
content: "Why don't you check yourself?",
}),
new ToolMessage({
tool_call_id: "654_retriever_tool",
content: "The weather in Paris, France is horrible.",
}),
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
output: {
converseSystem: [
{
text: "You're an advanced AI assistant.",
},
],
converseMessages: [
{
role: "user",
content: [
{
text: "What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check.",
},
],
},
{
role: "assistant",
content: [
{
toolUse: {
name: "retrieverTool",
toolUseId: "123_retriever_tool",
input: {
url: "https://weather.com",
},
},
},
{
toolUse: {
name: "retrieverTool",
toolUseId: "456_retriever_tool",
input: {
url: "https://weather.com",
},
},
},
],
},
{
role: "user",
content: [
{
toolResult: {
toolUseId: "123_retriever_tool",
content: [
{
text: "The weather in Berkeley, CA is 70 degrees and sunny.",
},
],
},
},
{
toolResult: {
toolUseId: "456_retriever_tool",
content: [
{
text: "The weather in Paris, France is perfect.",
},
],
},
},
],
},
{
role: "user",
content: [
{
text: "What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check.",
},
],
},
{
role: "assistant",
content: [
{
toolUse: {
name: "retrieverTool",
toolUseId: "321_retriever_tool",
input: {
url: "https://meteofrance.com",
},
},
},
{
toolUse: {
name: "retrieverTool",
toolUseId: "654_retriever_tool",
input: {
url: "https://meteofrance.com",
},
},
},
],
},
{
role: "user",
content: [
{
toolResult: {
toolUseId: "321_retriever_tool",
content: [
{
text: "Why don't you check yourself?",
},
],
},
},
{
toolResult: {
toolUseId: "654_retriever_tool",
content: [
{
text: "The weather in Paris, France is horrible.",
},
],
},
},
],
},
],
},
},
];

const { converseMessages, converseSystem } =
convertToConverseMessages(messages);

expect(converseSystem).toHaveLength(1);
expect(converseSystem[0].text).toBe("You're an advanced AI assistant.");

expect(converseMessages).toHaveLength(3);

const userMsgs = converseMessages.filter((msg) => msg.role === "user");
// Length of two because of the first user question, and tool use
// messages will have the user role.
expect(userMsgs).toHaveLength(2);
const textUserMsg = userMsgs.find((msg) => msg.content?.[0].text);
expect(textUserMsg?.content?.[0].text).toBe(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
it.each(testCases.map((tc) => [tc.name, tc]))(
"convertToConverseMessages: case %s",
(_, tc) => {
const { converseMessages, converseSystem } = convertToConverseMessages(
tc.input
);
expect(converseMessages).toEqual(tc.output.converseMessages);
expect(converseSystem).toEqual(tc.output.converseSystem);
}
);

const toolUseUserMsg = userMsgs.find((msg) => msg.content?.[0].toolResult);
expect(toolUseUserMsg).toBeDefined();
expect(toolUseUserMsg?.content).toHaveLength(1);
if (!toolUseUserMsg?.content?.length) return;

const toolResultContent = toolUseUserMsg.content[0];
expect(toolResultContent).toBeDefined();
expect(toolResultContent.toolResult?.toolUseId).toBe("123_retriever_tool");
expect(toolResultContent.toolResult?.content?.[0].text).toBe(
"The weather in Berkeley, CA is 70 degrees and sunny."
);

const assistantMsg = converseMessages.find((msg) => msg.role === "assistant");
expect(assistantMsg).toBeDefined();
if (!assistantMsg) return;

const toolUseContent = assistantMsg.content?.find((c) => "toolUse" in c);
expect(toolUseContent).toBeDefined();
expect(toolUseContent?.toolUse?.name).toBe("retrieverTool");
expect(toolUseContent?.toolUse?.toolUseId).toBe("123_retriever_tool");
expect(toolUseContent?.toolUse?.input).toEqual({
url: "https://weather.com",
});
});

test("Streaming supports empty string chunks", async () => {
Expand Down

0 comments on commit 87d92d9

Please sign in to comment.