Skip to content

Commit

Permalink
anthropic[patch]: Fix passing streamed tool calls back to anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 24, 2024
1 parent 636c17f commit dc574f6
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 6 deletions.
92 changes: 86 additions & 6 deletions libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,18 @@ function _makeMessageChunkFromAnthropicEvent(
streamUsage: boolean;
coerceContentToString: boolean;
usageData: { input_tokens: number; output_tokens: number };
toolUse?: {
id: string;
name: string;
};
}
): {
chunk: AIMessageChunk;
usageData: { input_tokens: number; output_tokens: number };
toolUse?: {
id: string;
name: string;
};
} | null {
let usageDataCopy = { ...fields.usageData };

Expand Down Expand Up @@ -233,6 +241,10 @@ function _makeMessageChunkFromAnthropicEvent(
additional_kwargs: {},
}),
usageData: usageDataCopy,
toolUse: {
id: data.content_block.id,
name: data.content_block.name,
},
};
} else if (
data.type === "content_block_delta" &&
Expand Down Expand Up @@ -274,6 +286,25 @@ function _makeMessageChunkFromAnthropicEvent(
}),
usageData: usageDataCopy,
};
} else if (data.type === "content_block_stop" && fields.toolUse) {
// Only yield the ID & name when the tool_use block is complete.
// This is so the names & IDs do not get concatenated.
return {
chunk: new AIMessageChunk({
content: fields.coerceContentToString
? ""
: [
{
id: fields.toolUse.id,
name: fields.toolUse.name,
index: data.index,
type: "input_json_delta",
},
],
additional_kwargs: {},
}),
usageData: usageDataCopy,
};
}

return null;
Expand Down Expand Up @@ -424,6 +455,9 @@ export function _convertLangChainToolCallToAnthropic(
}

function _formatContent(content: MessageContent) {
const toolTypes = ["tool_use", "tool_result", "input_json_delta"];
const textTypes = ["text", "text_delta"];

if (typeof content === "string") {
return content;
} else {
Expand All @@ -439,16 +473,34 @@ function _formatContent(content: MessageContent) {
type: "image" as const, // Explicitly setting the type as "image"
source,
};
} else if (contentPart.type === "text") {
} else if (textTypes.find((t) => t === contentPart.type) && "text" in contentPart) {
// Assuming contentPart is of type MessageContentText here
return {
type: "text" as const, // Explicitly setting the type as "text"
text: contentPart.text,
};
} else if (
contentPart.type === "tool_use" ||
contentPart.type === "tool_result"
toolTypes.find((t) => t === contentPart.type)
) {
if ("index" in contentPart) {
// Anthropic does not support passing the index field here, so we remove it
delete contentPart.index;

Check failure on line 487 in libs/langchain-anthropic/src/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Assignment to property of function parameter 'contentPart'
}

if (contentPart.type === "input_json_delta") {
// If type is `input_json_delta`, rename to `tool_use` for Anthropic
contentPart.type = "tool_use";

Check failure on line 492 in libs/langchain-anthropic/src/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Assignment to property of function parameter 'contentPart'
}

if ("input" in contentPart) {
// If the input is a JSON string, attempt to parse it
try {
contentPart.input = JSON.parse(contentPart.input);

Check failure on line 498 in libs/langchain-anthropic/src/chat_models.ts

View workflow job for this annotation

GitHub Actions / Check linting

Assignment to property of function parameter 'contentPart'
} catch {
// no-op
}
}

// TODO: Fix when SDK types are fixed
return {
...contentPart,
Expand Down Expand Up @@ -519,7 +571,9 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): {
const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) =>
content.find(
(contentPart) =>
contentPart.type === "tool_use" && contentPart.id === toolCall.id
(contentPart.type === "tool_use" ||
contentPart.type === "input_json_delta") &&
contentPart.id === toolCall.id
)
);
if (hasMismatchedToolCalls) {
Expand Down Expand Up @@ -581,12 +635,14 @@ function extractToolCallChunk(
) {
if (typeof inputJsonDeltaChunks.input === "string") {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
args: inputJsonDeltaChunks.input,
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
};
} else {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
args: JSON.stringify(inputJsonDeltaChunks.input, null, 2),
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
Expand Down Expand Up @@ -919,6 +975,14 @@ export class ChatAnthropicMessages<
let usageData = { input_tokens: 0, output_tokens: 0 };

let concatenatedChunks: AIMessageChunk | undefined;
// Anthropic only yields the tool name and id once, so we need to save those
// so we can yield them with the rest of the tool_use content.
let toolUse:
| {
id: string;
name: string;
}
| undefined;

for await (const data of stream) {
if (options.signal?.aborted) {
Expand All @@ -930,12 +994,25 @@ export class ChatAnthropicMessages<
streamUsage: !!(this.streamUsage || options.streamUsage),
coerceContentToString,
usageData,
toolUse: toolUse ? {
id: toolUse.id,
name: toolUse.name,
} : undefined,
});
if (!result) continue;

const { chunk, usageData: updatedUsageData } = result;
const {
chunk,
usageData: updatedUsageData,
toolUse: updatedToolUse,
} = result;

usageData = updatedUsageData;

if (updatedToolUse) {
toolUse = updatedToolUse;
}

const newToolCallChunk = extractToolCallChunk(chunk);
// Maintain concatenatedChunks for accessing the complete `tool_use` content block.
concatenatedChunks = concatenatedChunks
Expand Down Expand Up @@ -1015,11 +1092,14 @@ export class ChatAnthropicMessages<
},
}
: requestOptions;
const formattedMsgs = _formatMessagesForAnthropic(messages);
console.log("formattedMsgs");
console.dir(formattedMsgs, { depth: null });
const response = await this.completionWithRetry(
{
...params,
stream: false,
..._formatMessagesForAnthropic(messages),
...formattedMsgs,
},
options
);
Expand Down
41 changes: 41 additions & 0 deletions libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,44 @@ test("llm token callbacks can handle tool calls", async () => {
if (!args) return;
expect(args).toEqual(JSON.parse(tokens));
});

test.only("Anthropic can stream tool calls, and invoke again with that tool call", async () => {

Check failure on line 444 in libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts

View workflow job for this annotation

GitHub Actions / Check linting

Unexpected focused test
const input = [
new HumanMessage("What is the weather in SF?"),
];

const weatherTool = tool(
(_) => "The weather in San Francisco is 25°C",
{
name: "get_weather",
description: zodSchema.description,
schema: zodSchema,
}
);

const modelWithTools = model.bindTools([weatherTool]);

const stream = await modelWithTools.stream(input);

let finalChunk: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}
if (!finalChunk) {
throw new Error("chunk not defined");
}
// Push the AI message with the tool call to the input array.
input.push(finalChunk);
// Push a ToolMessage to the input array to represent the tool call response.
input.push(
new ToolMessage({
tool_call_id: finalChunk.tool_calls?.[0].id ?? "",
content:
"The weather in San Francisco is currently 25 degrees and sunny.",
name: "get_weather",
})
);
// Invoke again to ensure Anthropic can handle it's own tool call.
const finalResult = await modelWithTools.invoke(input);
console.dir(finalResult, { depth: null });
});

0 comments on commit dc574f6

Please sign in to comment.