Skip to content

Commit

Permalink
Default ChatAnthropic to sonnet 3.5
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Jul 18, 2024
1 parent 0b0f5fd commit 1f61c1e
Showing 1 changed file with 81 additions and 49 deletions.
130 changes: 81 additions & 49 deletions packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
width: 275,
},
data: {
model: 'claude-2',
model: 'claude-3-5-sonnet-20240620',
useModelInput: false,

temperature: 0.5,
Expand Down Expand Up @@ -220,7 +220,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
id: 'all-messages' as PortId,
title: 'All Messages',
description: 'All messages, with the response appended.',
})
});

return outputs;
},
Expand Down Expand Up @@ -334,13 +334,19 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
? coerceTypeOptional(inputs['stop' as PortId], 'string') ?? data.stop
: undefined
: data.stop;
const tools = data.enableToolUse ? (coerceTypeOptional(inputs['tools' as PortId], 'gpt-function[]') ?? []) : undefined;
const tools = data.enableToolUse
? coerceTypeOptional(inputs['tools' as PortId], 'gpt-function[]') ?? []
: undefined;
const rivetChatMessages = getChatMessages(inputs);
const messages = await chatMessagesToClaude3ChatMessages(rivetChatMessages);
let prompt = messages.reduce((acc, message) => {
const content = typeof message.content === 'string'
? message.content
: message.content.filter((c): c is Claude3ChatMessageTextContentPart => c.type === 'text').map((c) => c.text ?? '').join('');
const content =
typeof message.content === 'string'
? message.content
: message.content
.filter((c): c is Claude3ChatMessageTextContentPart => c.type === 'text')
.map((c) => c.text ?? '')
.join('');
if (message.role === 'user') {
return `${acc}\n\nHuman: ${content}`;
} else if (message.role === 'assistant') {
Expand All @@ -349,10 +355,10 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
return acc;
}, '');
prompt += '\n\nAssistant:';

// Get the "System" prompt input for Claude 3 models
const system = data.model.startsWith('claude-3') ? getSystemPrompt(inputs) : undefined;

let { maxTokens } = data;
const tokenizerInfo: TokenizerCallInfo = {
node: context.node,
Expand Down Expand Up @@ -400,7 +406,9 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
stop_sequences: stop ? [stop] : undefined,
system: system,
messages,
tools: tools ? tools.map((tool) => ({ name: tool.name, description: tool.description, input_schema: tool.parameters })) : undefined,
tools: tools
? tools.map((tool) => ({ name: tool.name, description: tool.description, input_schema: tool.parameters }))
: undefined,
};
const useMessageApi = model.startsWith('claude-3');
const cacheKey = JSON.stringify(useMessageApi ? messageOptions : completionOptions);
Expand All @@ -410,18 +418,21 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
return cached;
}
}

const startTime = Date.now();
const apiKey = context.getPluginConfig('anthropicApiKey');

if (useMessageApi && data.enableToolUse) {
// Streaming is not supported with tool usage.
const response = await callMessageApi({
apiKey: apiKey ?? '',
...messageOptions,
});
const { input_tokens: requestTokens, output_tokens: responseTokens } = response.usage;
const responseText = response.content.map((c): string | undefined => (c as any).text).filter(isNotNull).join('');
const responseText = response.content
.map((c): string | undefined => (c as any).text)
.filter(isNotNull)
.join('');
output['response' as PortId] = {
type: 'string',
value: responseText,
Expand All @@ -446,16 +457,20 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
{
type: 'assistant',
message: responseText,
function_call: functionCalls.length > 0 ? functionCalls.map((toolCall) => ({
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments),
id: toolCall.id,
}))[0] : undefined,
function_call:
functionCalls.length > 0
? functionCalls.map((toolCall) => ({
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments),
id: toolCall.id,
}))[0]
: undefined,
} satisfies ChatMessage,
]
],
};
output['requestTokens' as PortId] = { type: 'number', value: requestTokens ?? tokenCountEstimate };
const responseTokenCount = responseTokens ?? context.tokenizer.getTokenCountForString(responseText, tokenizerInfo);
const responseTokenCount =
responseTokens ?? context.tokenizer.getTokenCountForString(responseText, tokenizerInfo);
output['responseTokens' as PortId] = { type: 'number', value: responseTokenCount };
} else if (useMessageApi) {
// Use the messages API for Claude 3 models
Expand All @@ -464,16 +479,17 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
signal: context.signal,
...messageOptions,
});

// Process the response chunks and update the output
const responseParts: string[] = [];
let requestTokens: number | undefined = undefined, responseTokens: number | undefined = undefined;
let requestTokens: number | undefined = undefined,
responseTokens: number | undefined = undefined;
for await (const chunk of chunks) {
let completion: string = '';
if (chunk.type === 'content_block_start') {
completion = chunk.content_block.text;
} else if (chunk.type === 'content_block_delta') {
completion = chunk.delta.text;
completion = chunk.delta.text;
} else if (chunk.type === 'message_start' && chunk.message?.usage?.input_tokens) {
requestTokens = chunk.message.usage.input_tokens;
} else if (chunk.type === 'message_delta' && chunk.delta?.usage?.output_tokens) {
Expand All @@ -496,17 +512,18 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
message: responseParts.join('').trim(),
function_call: undefined,
} satisfies ChatMessage,
]
],
};
context.onPartialOutputs?.(output);
}

if (responseParts.length === 0) {
throw new Error('No response from Anthropic');
}

output['requestTokens' as PortId] = { type: 'number', value: requestTokens ?? tokenCountEstimate };
const responseTokenCount = responseTokens ?? context.tokenizer.getTokenCountForString(responseParts.join(''), tokenizerInfo);
const responseTokenCount =
responseTokens ?? context.tokenizer.getTokenCountForString(responseParts.join(''), tokenizerInfo);
output['responseTokens' as PortId] = { type: 'number', value: responseTokenCount };
} else {
// Use the normal chat completion method for non-Claude 3 models
Expand All @@ -515,7 +532,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
signal: context.signal,
...completionOptions,
});

// Process the response chunks and update the output
const responseParts: string[] = [];
for await (const chunk of chunks) {
Expand All @@ -533,7 +550,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
if (responseParts.length === 0) {
throw new Error('No response from Anthropic');
}

output['all-messages' as PortId] = {
type: 'chat-message[]',
value: [
Expand All @@ -543,23 +560,26 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
message: responseParts.join('').trim(),
function_call: undefined,
} satisfies ChatMessage,
]
],
};
output['requestTokens' as PortId] = { type: 'number', value: tokenCountEstimate };
const responseTokenCount = context.tokenizer.getTokenCountForString(responseParts.join(''), tokenizerInfo);
output['responseTokens' as PortId] = { type: 'number', value: responseTokenCount };
}

const cost = getCostForTokens({
requestTokens: output['requestTokens' as PortId]?.value as number,
responseTokens: output['responseTokens' as PortId]?.value as number,
}, model);
const cost = getCostForTokens(
{
requestTokens: output['requestTokens' as PortId]?.value as number,
responseTokens: output['responseTokens' as PortId]?.value as number,
},
model,
);
if (cost != null) {
output['cost' as PortId] = { type: 'number', value: cost };
}

const endTime = Date.now();

const duration = endTime - startTime;
output['duration' as PortId] = { type: 'number', value: duration };

Expand Down Expand Up @@ -626,7 +646,7 @@ function getChatMessages(inputs: Inputs) {
if (!prompt) {
throw new Error('Prompt is required');
}

const chatMessages = match(prompt)
.with({ type: 'chat-message' }, (p) => [p.value])
.with({ type: 'chat-message[]' }, (p) => p.value)
Expand All @@ -643,24 +663,26 @@ function getChatMessages(inputs: Inputs) {
'string',
),
);

return stringValues.filter((v) => v != null).map((v) => ({ type: 'user', message: v }));
}

const coercedMessage = coerceType(p, 'chat-message');
if (coercedMessage != null) {
return [coercedMessage];
}

const coercedString = coerceType(p, 'string');
return coercedString != null ? [{ type: 'user', message: coerceType(p, 'string') }] : [];
});

return chatMessages;
}
}

/**
 * Converts an array of Rivet `ChatMessage`s into the Claude 3 (Messages API)
 * message format.
 *
 * Messages for which the per-message converter returns a nullish value are
 * dropped from the result (filtered via `isNotNull`).
 *
 * @param chatMessages - The Rivet chat messages to convert.
 * @returns The converted Claude 3 messages, in the original order.
 */
export async function chatMessagesToClaude3ChatMessages(chatMessages: ChatMessage[]): Promise<Claude3ChatMessage[]> {
  // Note: the diff rendering showed both the old single-line and new wrapped
  // form of this body; this is the post-commit behavior, deduplicated.
  const converted = await Promise.all(chatMessages.map(chatMessageToClaude3ChatMessage));
  return converted.filter(isNotNull);
}
Expand All @@ -671,17 +693,23 @@ async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise<Cl
}
if (message.type === 'function') {
// Interpret function messages as user messages with tool_result content items (making Claude API more similar to OpenAI's)
const content = (Array.isArray(message.message) ? message.message : [message.message]).map((m) => typeof m === 'string' ? { type: 'text' as const, text: m } : undefined).filter(isNotNull);
const content = (Array.isArray(message.message) ? message.message : [message.message])
.map((m) => (typeof m === 'string' ? { type: 'text' as const, text: m } : undefined))
.filter(isNotNull);
return {
role: 'user',
content: [{
type: 'tool_result',
tool_use_id: message.name,
content: content.length === 1 ? content[0]!.text : content,
}],
content: [
{
type: 'tool_result',
tool_use_id: message.name,
content: content.length === 1 ? content[0]!.text : content,
},
],
};
}
const content = Array.isArray(message.message) ? await Promise.all(message.message.map(chatMessageContentToClaude3ChatMessage)) : [await chatMessageContentToClaude3ChatMessage(message.message)];
const content = Array.isArray(message.message)
? await Promise.all(message.message.map(chatMessageContentToClaude3ChatMessage))
: [await chatMessageContentToClaude3ChatMessage(message.message)];
if (message.type === 'assistant' && message.function_call) {
content.push({
type: 'tool_use',
Expand All @@ -696,7 +724,9 @@ async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise<Cl
};
}

async function chatMessageContentToClaude3ChatMessage(content: ChatMessageMessagePart): Promise<Claude3ChatMessageContentPart> {
async function chatMessageContentToClaude3ChatMessage(
content: ChatMessageMessagePart,
): Promise<Claude3ChatMessageContentPart> {
if (typeof content === 'string') {
return {
type: 'text',
Expand All @@ -720,11 +750,13 @@ async function chatMessageContentToClaude3ChatMessage(content: ChatMessageMessag
}
}

function getCostForTokens(tokenCounts: { requestTokens: number; responseTokens: number; }, model: AnthropicModels): number | undefined {
function getCostForTokens(
tokenCounts: { requestTokens: number; responseTokens: number },
model: AnthropicModels,
): number | undefined {
const modelInfo = anthropicModels[model];
if (modelInfo == null) {
return undefined;
}
return modelInfo.cost.prompt * tokenCounts.requestTokens + modelInfo.cost.completion * tokenCounts.responseTokens;
}

0 comments on commit 1f61c1e

Please sign in to comment.