diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index a0a0aae6333c..a9fcbb118944 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -160,10 +160,18 @@ function _makeMessageChunkFromAnthropicEvent( streamUsage: boolean; coerceContentToString: boolean; usageData: { input_tokens: number; output_tokens: number }; + toolUse?: { + id: string; + name: string; + }; } ): { chunk: AIMessageChunk; usageData: { input_tokens: number; output_tokens: number }; + toolUse?: { + id: string; + name: string; + }; } | null { let usageDataCopy = { ...fields.usageData }; @@ -233,6 +241,10 @@ function _makeMessageChunkFromAnthropicEvent( additional_kwargs: {}, }), usageData: usageDataCopy, + toolUse: { + id: data.content_block.id, + name: data.content_block.name, + }, }; } else if ( data.type === "content_block_delta" && @@ -274,6 +286,25 @@ function _makeMessageChunkFromAnthropicEvent( }), usageData: usageDataCopy, }; + } else if (data.type === "content_block_stop" && fields.toolUse) { + // Only yield the ID & name when the tool_use block is complete. + // This is so the names & IDs do not get concatenated. + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + id: fields.toolUse.id, + name: fields.toolUse.name, + index: data.index, + type: "input_json_delta", + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; } return null; @@ -424,6 +455,9 @@ export function _convertLangChainToolCallToAnthropic( } function _formatContent(content: MessageContent) { + const toolTypes = ["tool_use", "tool_result", "input_json_delta"]; + const textTypes = ["text", "text_delta"]; + if (typeof content === "string") { return content; } else { @@ -439,19 +473,40 @@ function _formatContent(content: MessageContent) { type: "image" as const, // Explicitly setting the type as "image" source, }; - } else if (contentPart.type === "text") { + } else if ( + textTypes.find((t) => t === contentPart.type) && + "text" in contentPart + ) { // Assuming contentPart is of type MessageContentText here return { type: "text" as const, // Explicitly setting the type as "text" text: contentPart.text, }; - } else if ( - contentPart.type === "tool_use" || - contentPart.type === "tool_result" - ) { + } else if (toolTypes.find((t) => t === contentPart.type)) { + const contentPartCopy = { ...contentPart }; + if ("index" in contentPartCopy) { + // Anthropic does not support passing the index field here, so we remove it. + delete contentPartCopy.index; + } + + if (contentPartCopy.type === "input_json_delta") { + // `input_json_delta` type only represents yielding partial tool inputs + // and is not a valid type for Anthropic messages. + contentPartCopy.type = "tool_use"; + } + + if ("input" in contentPartCopy) { + // Anthropic tool use inputs should be valid objects, when applicable. + try { + contentPartCopy.input = JSON.parse(contentPartCopy.input); + } catch { + // no-op + } + } + // TODO: Fix when SDK types are fixed return { - ...contentPart, + ...contentPartCopy, // eslint-disable-next-line @typescript-eslint/no-explicit-any } as any; } else { @@ -519,7 +574,9 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): { const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) => content.find( (contentPart) => - contentPart.type === "tool_use" && contentPart.id === toolCall.id + (contentPart.type === "tool_use" || + contentPart.type === "input_json_delta") && + contentPart.id === toolCall.id ) ); if (hasMismatchedToolCalls) { @@ -581,12 +638,16 @@ function extractToolCallChunk( ) { if (typeof inputJsonDeltaChunks.input === "string") { newToolCallChunk = { + id: inputJsonDeltaChunks.id, + name: inputJsonDeltaChunks.name, args: inputJsonDeltaChunks.input, index: inputJsonDeltaChunks.index, type: "tool_call_chunk", }; } else { newToolCallChunk = { + id: inputJsonDeltaChunks.id, + name: inputJsonDeltaChunks.name, args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), index: inputJsonDeltaChunks.index, type: "tool_call_chunk", @@ -919,6 +980,14 @@ export class ChatAnthropicMessages< let usageData = { input_tokens: 0, output_tokens: 0 }; let concatenatedChunks: AIMessageChunk | undefined; + // Anthropic only yields the tool name and id once, so we need to save those + // so we can yield them with the rest of the tool_use content. + let toolUse: + | { + id: string; + name: string; + } + | undefined; for await (const data of stream) { if (options.signal?.aborted) { @@ -930,12 +999,27 @@ export class ChatAnthropicMessages< streamUsage: !!(this.streamUsage || options.streamUsage), coerceContentToString, usageData, + toolUse: toolUse + ? { + id: toolUse.id, + name: toolUse.name, + } + : undefined, }); if (!result) continue; - const { chunk, usageData: updatedUsageData } = result; + const { + chunk, + usageData: updatedUsageData, + toolUse: updatedToolUse, + } = result; + usageData = updatedUsageData; + if (updatedToolUse) { + toolUse = updatedToolUse; + } + const newToolCallChunk = extractToolCallChunk(chunk); // Maintain concatenatedChunks for accessing the complete `tool_use` content block. concatenatedChunks = concatenatedChunks diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts index f3151ed32b59..ac73ef8631a6 100644 --- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts +++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts @@ -10,10 +10,11 @@ import { getBufferString, } from "@langchain/core/messages"; import { z } from "zod"; -import { StructuredTool } from "@langchain/core/tools"; +import { StructuredTool, tool } from "@langchain/core/tools"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { RunnableLambda } from "@langchain/core/runnables"; +import { concat } from "@langchain/core/utils/stream"; import { BaseChatModelsTests, BaseChatModelsTestsFields, @@ -522,6 +523,159 @@ export abstract class ChatModelIntegrationTests< expect(cacheValue2).toEqual(cacheValue); } + /** + * This test verifies models can invoke a tool, and use the AIMessage + * with the tool call in a followup request. This is useful when building + * agents, or other pipelines that invoke tools. + */ + async testModelCanUseToolUseAIMessage() { + if (!this.chatModelHasToolCalling) { + console.log("Test requires tool calling. Skipping..."); + return; + } + + const model = new this.Cls(this.constructorArgs); + if (!model.bindTools) { + throw new Error( + "bindTools undefined. Cannot test OpenAI formatted tool calls." + ); + } + + const weatherSchema = z.object({ + location: z.string().describe("The location to get the weather for."), + }); + + // Define the tool + const weatherTool = tool( + (_) => "The weather in San Francisco is 70 degrees and sunny.", + { + name: "get_current_weather", + schema: weatherSchema, + description: "Get the current weather for a location.", + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + + // List of messages to initially invoke the model with, and to hold + // followup messages to invoke the model with. + const messages = [ + new HumanMessage( + "What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer." + ), + ]; + + const result: AIMessage = await modelWithTools.invoke(messages); + + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) { + throw new Error("result.tool_calls is undefined"); + } + const { tool_calls } = result; + expect(tool_calls[0].name).toBe("get_current_weather"); + + // Push the result of the tool call into the messages array so we can + // confirm in the followup request the model can use the tool call. + messages.push(result); + + // Create a dummy ToolMessage representing the output of the tool call. + const toolMessage = new ToolMessage({ + tool_call_id: tool_calls[0].id ?? "", + name: tool_calls[0].name, + content: await weatherTool.invoke( + tool_calls[0].args as z.infer + ), + }); + messages.push(toolMessage); + + const finalResult = await modelWithTools.invoke(messages); + + expect(finalResult.content).not.toBe(""); + } + + /** + * Same as the above test, but streaming both model invocations. + */ + async testModelCanUseToolUseAIMessageWithStreaming() { + if (!this.chatModelHasToolCalling) { + console.log("Test requires tool calling. Skipping..."); + return; + } + + const model = new this.Cls(this.constructorArgs); + if (!model.bindTools) { + throw new Error( + "bindTools undefined. Cannot test OpenAI formatted tool calls." + ); + } + + const weatherSchema = z.object({ + location: z.string().describe("The location to get the weather for."), + }); + + // Define the tool + const weatherTool = tool( + (_) => "The weather in San Francisco is 70 degrees and sunny.", + { + name: "get_current_weather", + schema: weatherSchema, + description: "Get the current weather for a location.", + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + + // List of messages to initially invoke the model with, and to hold + // followup messages to invoke the model with. + const messages = [ + new HumanMessage( + "What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer." + ), + ]; + + const stream = await modelWithTools.stream(messages); + let result: AIMessageChunk | undefined; + for await (const chunk of stream) { + result = !result ? chunk : concat(result, chunk); + } + + expect(result).toBeDefined(); + if (!result) return; + + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) { + throw new Error("result.tool_calls is undefined"); + } + + const { tool_calls } = result; + expect(tool_calls[0].name).toBe("get_current_weather"); + + // Push the result of the tool call into the messages array so we can + // confirm in the followup request the model can use the tool call. + messages.push(result); + + // Create a dummy ToolMessage representing the output of the tool call. + const toolMessage = new ToolMessage({ + tool_call_id: tool_calls[0].id ?? "", + name: tool_calls[0].name, + content: await weatherTool.invoke( + tool_calls[0].args as z.infer + ), + }); + messages.push(toolMessage); + + const finalStream = await modelWithTools.stream(messages); + let finalResult: AIMessageChunk | undefined; + for await (const chunk of finalStream) { + finalResult = !finalResult ? chunk : concat(finalResult, chunk); + } + + expect(finalResult).toBeDefined(); + if (!finalResult) return; + + expect(finalResult.content).not.toBe(""); + } + /** * Run all unit tests for the chat model. * Each test is wrapped in a try/catch block to prevent the entire test suite from failing. @@ -629,6 +783,20 @@ export abstract class ChatModelIntegrationTests< console.error("testCacheComplexMessageTypes failed", e); } + try { + await this.testModelCanUseToolUseAIMessage(); + } catch (e: any) { + allTestsPassed = false; + console.error("testModelCanUseToolUseAIMessage failed", e); + } + + try { + await this.testModelCanUseToolUseAIMessageWithStreaming(); + } catch (e: any) { + allTestsPassed = false; + console.error("testModelCanUseToolUseAIMessageWithStreaming failed", e); + } + return allTestsPassed; } }