From 9a8f6e5e5d806407c4eebff2bebc8dac7909b6d7 Mon Sep 17 00:00:00 2001 From: Cai GoGwilt Date: Fri, 5 Apr 2024 13:14:52 -0700 Subject: [PATCH] Adds tool support for Claude --- .../core/src/plugins/anthropic/anthropic.ts | 85 +++++++++- .../anthropic/nodes/ChatAnthropicNode.ts | 151 ++++++++++++++++-- 2 files changed, 219 insertions(+), 17 deletions(-) diff --git a/packages/core/src/plugins/anthropic/anthropic.ts b/packages/core/src/plugins/anthropic/anthropic.ts index fbc8954ac..9fa415b1d 100644 --- a/packages/core/src/plugins/anthropic/anthropic.ts +++ b/packages/core/src/plugins/anthropic/anthropic.ts @@ -80,16 +80,39 @@ export type Claude3ChatMessage = { content: string | Claude3ChatMessageContentPart[]; } -export type Claude3ChatMessageContentPart = { - type: 'text' | 'image'; - text?: string; - source?: { +export type Claude3ChatMessageTextContentPart = { + type: 'text'; + text: string; +}; + +export type Claude3ChatMessageImageContentPart = { + type: 'image'; + source: { type: 'base64'; media_type: string; data: string; }; }; +export type Claude3ChatMessageToolResultContentPart = { + type: 'tool_result'; + tool_use_id: string; + content: string | { type: 'text'; text: string; }[]; +}; + +export type Claude3ChatMessageToolUseContentPart = { + type: 'tool_use'; + id: string; + name: string; + input: object; +} + +export type Claude3ChatMessageContentPart = + | Claude3ChatMessageTextContentPart + | Claude3ChatMessageImageContentPart + | Claude3ChatMessageToolResultContentPart + | Claude3ChatMessageToolUseContentPart; + export type ChatMessageOptions = { apiKey: string; model: AnthropicModels; @@ -102,6 +125,11 @@ export type ChatMessageOptions = { top_k?: number; signal?: AbortSignal; stream?: boolean; + tools?: { + name: string; + description: string; + input_schema: object; + }[]; }; export type ChatCompletionOptions = { @@ -168,7 +196,25 @@ export type ChatMessageChunk = { } } | { type: 'message_stop'; -} +}; + +export type ChatMessageResponse = { + id: string; + content: ({ + text: string; + } | { + id: string; + name: string; + input: object; + })[]; + model: string; + stop_reason: 'end_turn'; + stop_sequence: string; + usage: { + input_tokens: number; + output_tokens: number; + }; +}; export async function* streamChatCompletions({ apiKey, @@ -220,6 +266,35 @@ export async function* streamChatCompletions({ } } +export async function callMessageApi({ + apiKey, + signal, + tools, + ...rest +}: ChatMessageOptions): Promise { + const defaultSignal = new AbortController().signal; + const response = await fetch('https://api.anthropic.com/v1/messages', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-api-key': apiKey, + 'anthropic-version': '2023-06-01', + 'anthropic-beta': tools ? 'tools-2024-04-04' : 'messages-2023-12-15', + }, + body: JSON.stringify({ + ...rest, + tools, + stream: false, + }), + signal: signal ?? defaultSignal, + }); + const responseJson = await response.json(); + if (response.status !== 200) { + throw new AnthropicError(responseJson?.error?.message ?? 'Request failed', response, responseJson); + } + return responseJson; +} + export async function* streamMessageApi({ apiKey, signal, diff --git a/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts b/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts index cc13b84a4..bcef4d499 100644 --- a/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts +++ b/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts @@ -25,7 +25,9 @@ import { type Claude3ChatMessage, type Claude3ChatMessageContentPart, streamMessageApi, - type ChatMessageOptions + type ChatMessageOptions, + callMessageApi, + type Claude3ChatMessageTextContentPart, } from '../anthropic.js'; import { nanoid } from 'nanoid/non-secure'; import { dedent } from 'ts-dedent'; @@ -50,6 +52,7 @@ export type ChatAnthropicNodeConfigData = { top_k?: number; maxTokens: number; stop?: string; + enableToolUse?: boolean; }; export type ChatAnthropicNodeData = ChatAnthropicNodeConfigData & { @@ -107,6 +110,8 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { cache: false, useAsGraphPartialOutput: true, + + enableToolUse: false, }, }; @@ -173,6 +178,16 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { }); } + if (data.enableToolUse) { + inputs.push({ + dataType: ['gpt-function', 'gpt-function[]'] as const, + id: 'tools' as PortId, + title: 'Tools', + description: 'Tools to use in the model. To connect multiple tools, use an Array node.', + coerced: false, + }); + } + inputs.push({ dataType: ['chat-message', 'chat-message[]'] as const, id: 'prompt' as PortId, @@ -182,7 +197,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { return inputs; }, - getOutputDefinitions(_data): NodeOutputDefinition[] { + getOutputDefinitions(data): NodeOutputDefinition[] { const outputs: NodeOutputDefinition[] = []; outputs.push({ @@ -191,6 +206,22 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { title: 'Response', }); + if (data.enableToolUse) { + outputs.push({ + dataType: 'object[]', + id: 'function-calls' as PortId, + title: 'Function Calls', + description: 'The function calls that were made, if any.', + }); + } + + outputs.push({ + dataType: 'chat-message[]', + id: 'all-messages' as PortId, + title: 'All Messages', + description: 'All messages, with the response appended.', + }) + return outputs; }, @@ -266,6 +297,11 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { label: 'Use for subgraph partial output', dataKey: 'useAsGraphPartialOutput', }, + { + type: 'toggle', + label: 'Enable Tool Use (disables streaming)', + dataKey: 'enableToolUse', + }, ]; }, @@ -298,9 +334,13 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { ? coerceTypeOptional(inputs['stop' as PortId], 'string') ?? data.stop : undefined : data.stop; - const { messages } = await getChatAnthropicNodeMessages(inputs); + const tools = data.enableToolUse ? (coerceTypeOptional(inputs['tools' as PortId], 'gpt-function[]') ?? []) : undefined; + const rivetChatMessages = getChatMessages(inputs); + const messages = await chatMessagesToClaude3ChatMessages(rivetChatMessages); let prompt = messages.reduce((acc, message) => { - const content = typeof message.content === 'string' ? message.content : message.content.map((c) => c.text ?? '').join(''); + const content = typeof message.content === 'string' + ? message.content + : message.content.filter((c): c is Claude3ChatMessageTextContentPart => c.type === 'text').map((c) => c.text ?? '').join(''); if (message.role === 'user') { return `${acc}\n\nHuman: ${content}`; } else if (message.role === 'assistant') { @@ -360,6 +400,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { stop_sequences: stop ? [stop] : undefined, system: system, messages, + tools: tools ? tools.map((tool) => ({ name: tool.name, description: tool.description, input_schema: tool.parameters })) : undefined, }; const useMessageApi = model.startsWith('claude-3'); const cacheKey = JSON.stringify(useMessageApi ? messageOptions : completionOptions); @@ -373,7 +414,50 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { const startTime = Date.now(); const apiKey = context.getPluginConfig('anthropicApiKey'); - if (useMessageApi) { + if (useMessageApi && data.enableToolUse) { + // Streaming is not supported with tool usage. + const response = await callMessageApi({ + apiKey: apiKey ?? '', + ...messageOptions, + }); + const { input_tokens: requestTokens, output_tokens: responseTokens } = response.usage; + const responseText = response.content.map((c): string | undefined => (c as any).text).filter(isNotNull).join(''); + output['response' as PortId] = { + type: 'string', + value: responseText, + }; + const functionCalls = response.content + .filter((content) => (content as any).name && (content as any).id) + .map((functionCall: any) => ({ + name: functionCall.name, + arguments: functionCall.input, // Matches OpenAI ChatNode + id: functionCall.id, + })); + if (functionCalls.length > 0) { + output['function-calls' as PortId] = { + type: 'object[]', + value: functionCalls, + }; + } + output['all-messages' as PortId] = { + type: 'chat-message[]', + value: [ + ...rivetChatMessages, + { + type: 'assistant', + message: responseText, + function_call: functionCalls.length > 0 ? functionCalls.map((toolCall) => ({ + name: toolCall.name, + arguments: JSON.stringify(toolCall.arguments), + id: toolCall.id, + }))[0] : undefined, + } satisfies ChatMessage, + ] + }; + output['requestTokens' as PortId] = { type: 'number', value: requestTokens ?? tokenCountEstimate }; + const responseTokenCount = responseTokens ?? context.tokenizer.getTokenCountForString(responseText, tokenizerInfo); + output['responseTokens' as PortId] = { type: 'number', value: responseTokenCount }; + } else if (useMessageApi) { // Use the messages API for Claude 3 models const chunks = streamMessageApi({ apiKey: apiKey ?? '', @@ -403,6 +487,17 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { type: 'string', value: responseParts.join('').trim(), }; + output['all-messages' as PortId] = { + type: 'chat-message[]', + value: [ + ...rivetChatMessages, + { + type: 'assistant', + message: responseParts.join('').trim(), + function_call: undefined, + } satisfies ChatMessage, + ] + }; context.onPartialOutputs?.(output); } @@ -439,6 +534,17 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { throw new Error('No response from Anthropic'); } + output['all-messages' as PortId] = { + type: 'chat-message[]', + value: [ + ...rivetChatMessages, + { + type: 'assistant', + message: responseParts.join('').trim(), + function_call: undefined, + } satisfies ChatMessage, + ] + }; output['requestTokens' as PortId] = { type: 'number', value: tokenCountEstimate }; const responseTokenCount = context.tokenizer.getTokenCountForString(responseParts.join(''), tokenizerInfo); output['responseTokens' as PortId] = { type: 'number', value: responseTokenCount }; @@ -515,12 +621,12 @@ export function getSystemPrompt(inputs: Inputs) { } } -export async function getChatAnthropicNodeMessages(inputs: Inputs) { +function getChatMessages(inputs: Inputs) { const prompt = inputs['prompt' as PortId]; if (!prompt) { throw new Error('Prompt is required'); } - + const chatMessages = match(prompt) .with({ type: 'chat-message' }, (p) => [p.value]) .with({ type: 'chat-message[]' }, (p) => p.value) @@ -537,23 +643,26 @@ export async function getChatAnthropicNodeMessages(inputs: Inputs) { 'string', ), ); - + return stringValues.filter((v) => v != null).map((v) => ({ type: 'user', message: v })); } - + const coercedMessage = coerceType(p, 'chat-message'); if (coercedMessage != null) { return [coercedMessage]; } - + const coercedString = coerceType(p, 'string'); return coercedString != null ? [{ type: 'user', message: coerceType(p, 'string') }] : []; }); + return chatMessages; +} +export async function chatMessagesToClaude3ChatMessages(chatMessages: ChatMessage[]): Promise { const messages: Claude3ChatMessage[] = (await Promise.all(chatMessages.map(chatMessageToClaude3ChatMessage))).filter(isNotNull); - return { messages }; + return messages; } async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise { @@ -561,9 +670,26 @@ async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise typeof m === 'string' ? { type: 'text' as const, text: m } : undefined).filter(isNotNull); + return { + role: 'user', + content: [{ + type: 'tool_result', + tool_use_id: message.name, + content: content.length === 1 ? content[0]!.text : content, + }], + }; } const content = Array.isArray(message.message) ? await Promise.all(message.message.map(chatMessageContentToClaude3ChatMessage)) : [await chatMessageContentToClaude3ChatMessage(message.message)]; + if (message.type === 'assistant' && message.function_call) { + content.push({ + type: 'tool_use', + id: message.function_call.id!, + name: message.function_call.name, + input: JSON.parse(message.function_call.arguments), + }); + } return { role: message.type, content, @@ -593,6 +719,7 @@ async function chatMessageContentToClaude3ChatMessage(content: ChatMessageMessag assertNever(content); } } + function getCostForTokens(tokenCounts: { requestTokens: number; responseTokens: number; }, model: AnthropicModels): number | undefined { const modelInfo = anthropicModels[model]; if (modelInfo == null) {