Skip to content

Commit

Permalink
Fix parallel function calling with claude
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Jul 31, 2024
1 parent 4c51473 commit c0c0f49
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
42 changes: 40 additions & 2 deletions packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -444,12 +444,14 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
arguments: functionCall.input, // Matches OpenAI ChatNode
id: functionCall.id,
}));

if (functionCalls.length > 0) {
output['function-calls' as PortId] = {
type: 'object[]',
value: functionCalls,
};
}

output['all-messages' as PortId] = {
type: 'chat-message[]',
value: [
Expand All @@ -465,6 +467,11 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
id: toolCall.id,
}))[0]
: undefined,
function_calls: functionCalls.map((toolCall) => ({
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments),
id: toolCall.id,
})),
} satisfies ChatMessage,
],
};
Expand Down Expand Up @@ -511,6 +518,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
type: 'assistant',
message: responseParts.join('').trim(),
function_call: undefined,
function_calls: undefined,
} satisfies ChatMessage,
],
};
Expand Down Expand Up @@ -559,6 +567,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
type: 'assistant',
message: responseParts.join('').trim(),
function_call: undefined,
function_calls: undefined,
} satisfies ChatMessage,
],
};
Expand Down Expand Up @@ -684,7 +693,25 @@ export async function chatMessagesToClaude3ChatMessages(chatMessages: ChatMessag
isNotNull,
);

return messages;
// Combine sequential tool_result messages into a single user message with multiple content items
const combinedMessages = messages.reduce<Claude3ChatMessage[]>((acc, message) => {
if (
message.role === 'user' &&
Array.isArray(message.content) &&
message.content.length === 1 &&
message.content[0]!.type === 'tool_result'
) {
const last = acc.at(-1);
if (last?.role === 'user' && Array.isArray(last.content) && last.content.every((c) => c.type === 'tool_result')) {
const content = last.content.concat(message.content);
return [...acc.slice(0, -1), { ...last, content }];
}
}

return [...acc, message];
}, []);

return combinedMessages;
}

async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise<Claude3ChatMessage | undefined> {
Expand All @@ -707,10 +734,21 @@ async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise<Cl
],
};
}

const content = Array.isArray(message.message)
? await Promise.all(message.message.map(chatMessageContentToClaude3ChatMessage))
: [await chatMessageContentToClaude3ChatMessage(message.message)];
if (message.type === 'assistant' && message.function_call) {

if (message.type === 'assistant' && message.function_calls) {
content.push(
...message.function_calls.map((fc) => ({
type: 'tool_use' as const,
id: fc.id!,
name: fc.name,
input: JSON.parse(fc.arguments),
})),
);
} else if (message.type === 'assistant' && message.function_call) {
content.push({
type: 'tool_use',
id: message.function_call.id!,
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/plugins/google/nodes/ChatGoogleNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ export const ChatGoogleNodeImpl: PluginNodeImpl<ChatGoogleNode> = {
type: 'assistant',
message: responseParts.join('').trim() ?? '',
function_call: undefined,
function_calls: undefined,
},
],
};
Expand Down

0 comments on commit c0c0f49

Please sign in to comment.