From 864905efdd188590cfeac9b04bca5209ce143222 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 11 Jun 2024 15:30:10 -0700 Subject: [PATCH] implement .bind, fix call options, add docs --- .../docs/integrations/chat/bedrock.mdx | 28 ++++++ .../models/chat/integration_bedrock_tools.ts | 64 +++++++++++++ .../models/chat/integration_bedrock_wso.ts | 35 ++++++++ .../src/chat_models/bedrock/web.ts | 90 ++++++++++++------- .../chat_models/tests/chatbedrock.int.test.ts | 54 +++++++++-- .../src/output_parsers/bedrock.ts | 4 +- 6 files changed, 237 insertions(+), 38 deletions(-) create mode 100644 examples/src/models/chat/integration_bedrock_tools.ts create mode 100644 examples/src/models/chat/integration_bedrock_wso.ts diff --git a/docs/core_docs/docs/integrations/chat/bedrock.mdx b/docs/core_docs/docs/integrations/chat/bedrock.mdx index ba5ce7b96ed1..10cdc57c9d98 100644 --- a/docs/core_docs/docs/integrations/chat/bedrock.mdx +++ b/docs/core_docs/docs/integrations/chat/bedrock.mdx @@ -56,3 +56,31 @@ Anthropic Claude-3 models hosted on Bedrock have multimodal capabilities and can import BedrockMultimodalExample from "@examples/models/chat/integration_bedrock_multimodal.ts"; {BedrockMultimodalExample} + +### Tool calling + +:::info +Not all Bedrock models support tool calling. Please refer to the [model documentation](https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html) for more information. +::: + +The examples below demonstrate how to use tool calling, along with the `withStructuredOutput` method to easily compose structured output LLM calls. + +import ToolCalling from "@examples/models/chat/integration_bedrock_tools.ts"; + +{ToolCalling} + +:::tip +See the LangSmith trace [here](https://smith.langchain.com/public/003a684d-90eb-406e-a146-8ee5e617921b/r) +::: + +#### `.withStructuredOutput({ ... })` + +Using the `.withStructuredOutput` method, you can easily make the LLM return structured output, given only a Zod or JSON schema: + +import WSOExample from "@examples/models/chat/integration_bedrock_wso.ts"; + +{WSOExample} + +:::tip +See the LangSmith trace [here](https://smith.langchain.com/public/1f7b1ad8-e4ac-4965-8ce1-fae06005f3d7/r) +::: diff --git a/examples/src/models/chat/integration_bedrock_tools.ts b/examples/src/models/chat/integration_bedrock_tools.ts new file mode 100644 index 000000000000..62faa21865cc --- /dev/null +++ b/examples/src/models/chat/integration_bedrock_tools.ts @@ -0,0 +1,64 @@ + +import { BedrockChat } from "@langchain/community/chat_models/bedrock"; +// Or, from web environments: +// import { BedrockChat } from "@langchain/community/chat_models/bedrock/web"; +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; + +const model = new BedrockChat({ + region: process.env.BEDROCK_AWS_REGION, + model: "anthropic.claude-3-sonnet-20240229-v1:0", + maxRetries: 0, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, +}); + +const weatherSchema = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); + +const modelWithTools = model.bindTools([ + { + name: "weather_tool", + description: weatherSchema.description, + input_schema: zodToJsonSchema(weatherSchema), + }, +]); +// Optionally, you can bind tools via the `.bind` method: +// const modelWithTools = model.bind({ +// tools: [ +// { +// name: "weather_tool", +// description: weatherSchema.description, +// input_schema: zodToJsonSchema(weatherSchema), +// }, +// ], +// }); + +const res = await modelWithTools.invoke("What's the weather in New York?"); +console.log(res); + +/* +AIMessage { + additional_kwargs: { id: 'msg_bdrk_01JF7hb4PNQPywP4gnBbgpHi' }, + response_metadata: { + stop_reason: 'tool_use', + usage: { input_tokens: 300, output_tokens: 85 } + }, + tool_calls: [ + { + name: 'weather_tool', + args: { + city: 'New York', + state: 'NY' + }, + id: 'toolu_bdrk_01AtEZRTCKioFXqhoNcpgaV7' + } + ], +} +*/ diff --git a/examples/src/models/chat/integration_bedrock_wso.ts b/examples/src/models/chat/integration_bedrock_wso.ts new file mode 100644 index 000000000000..5a244d231709 --- /dev/null +++ b/examples/src/models/chat/integration_bedrock_wso.ts @@ -0,0 +1,35 @@ + +import { BedrockChat } from "@langchain/community/chat_models/bedrock"; +// Or, from web environments: +// import { BedrockChat } from "@langchain/community/chat_models/bedrock/web"; +import { z } from "zod"; + +const model = new BedrockChat({ + region: process.env.BEDROCK_AWS_REGION, + model: "anthropic.claude-3-sonnet-20240229-v1:0", + maxRetries: 0, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, +}); + +const weatherSchema = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); + +const modelWithStructuredOutput = model.withStructuredOutput(weatherSchema, { + name: "weather_tool", // Optional, defaults to 'extract' +}); + +const res = await modelWithStructuredOutput.invoke( + "What's the weather in New York?" +); +console.log(res); + +/* +{ city: 'New York', state: 'NY' } +*/ diff --git a/libs/langchain-community/src/chat_models/bedrock/web.ts b/libs/langchain-community/src/chat_models/bedrock/web.ts index 2b20728332b0..7c018e14d719 100644 --- a/libs/langchain-community/src/chat_models/bedrock/web.ts +++ b/libs/langchain-community/src/chat_models/bedrock/web.ts @@ -112,6 +112,29 @@ export function convertMessagesToPrompt( throw new Error(`Provider ${provider} does not support chat.`); } +function formatTools( + tools: (StructuredToolInterface | AnthropicTool)[] +): AnthropicTool[] { + return tools.map((tool) => { + if (isStructuredTool(tool)) { + return { + name: tool.name, + description: tool.description, + input_schema: zodToJsonSchema(tool.schema), + }; + } + return tool; + }); +} + +export interface BedrockChatCallOptions extends BaseChatModelCallOptions { + tools?: (StructuredToolInterface | AnthropicTool)[]; +} + +export interface BedrockChatFields + extends Partial, + BaseChatModelParams {} + /** * A type of Large Language Model (LLM) that interacts with the Bedrock * service. It extends the base `LLM` class and implements the @@ -208,7 +231,10 @@ export function convertMessagesToPrompt( * runStreaming().catch(console.error); * ``` */ -export class BedrockChat extends BaseChatModel implements BaseBedrockInput { +export class BedrockChat + extends BaseChatModel + implements BaseBedrockInput +{ model = "amazon.titan-tg1-large"; region: string; @@ -281,7 +307,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { return "BedrockChat"; } - constructor(fields?: Partial & BaseChatModelParams) { + constructor(fields?: BedrockChatFields) { super(fields ?? {}); this.model = fields?.model ?? this.model; @@ -331,11 +357,14 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { } override invocationParams(options?: this["ParsedCallOptions"]) { + const callOptionTools = formatTools(options?.tools ?? []); return { - tools: this._anthropicTools, + tools: [...(this._anthropicTools ?? []), ...callOptionTools], temperature: this.temperature, max_tokens: this.maxTokens, - stop: options?.stop, + stop: options?.stop ?? this.stopSequences, + modelKwargs: this.modelKwargs, + guardrailConfig: this.guardrailConfig, }; } @@ -353,7 +382,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { async _generate( messages: BaseMessage[], - options: Partial, + options: Partial, runManager?: CallbackManagerForLLMRun ): Promise { if (this.streaming) { @@ -381,7 +410,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { async _generateNonStreaming( messages: BaseMessage[], - options: Partial, + options: Partial, _runManager?: CallbackManagerForLLMRun ): Promise { const service = "bedrock-runtime"; @@ -425,26 +454,34 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { } ) { const { bedrockMethod, endpointHost, provider } = fields; + const { + max_tokens, + temperature, + stop, + modelKwargs, + guardrailConfig, + tools, + } = this.invocationParams(options); const inputBody = this.usesMessagesApi ? BedrockLLMInputOutputAdapter.prepareMessagesInput( provider, messages, - this.maxTokens, - this.temperature, - options.stop ?? this.stopSequences, - this.modelKwargs, - this.guardrailConfig, - this._anthropicTools + max_tokens, + temperature, + stop, + modelKwargs, + guardrailConfig, + tools ) : BedrockLLMInputOutputAdapter.prepareInput( provider, convertMessagesToPromptAnthropic(messages), - this.maxTokens, - this.temperature, - options.stop ?? this.stopSequences, - this.modelKwargs, + max_tokens, + temperature, + stop, + modelKwargs, fields.bedrockMethod, - this.guardrailConfig + guardrailConfig ); const url = new URL( @@ -694,11 +731,11 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { override bindTools( tools: (StructuredToolInterface | AnthropicTool)[], - _kwargs?: Partial + _kwargs?: Partial ): Runnable< BaseLanguageModelInput, BaseMessageChunk, - BaseChatModelCallOptions + this["ParsedCallOptions"] > { const provider = this.model.split(".")[0]; if (provider !== "anthropic") { @@ -706,16 +743,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { "Currently, tool calling through Bedrock is only supported for Anthropic models." ); } - this._anthropicTools = tools.map((tool) => { - if (isStructuredTool(tool)) { - return { - name: tool.name, - description: tool.description, - input_schema: zodToJsonSchema(tool.schema), - }; - } - return tool; - }); + this._anthropicTools = formatTools(tools); return this; } @@ -762,7 +790,9 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput { const method = config?.method; const includeRaw = config?.includeRaw; if (method === "jsonMode") { - throw new Error(`Anthropic only supports "functionCalling" as a method.`); + throw new Error( + `BedrockChat only supports "functionCalling" as a method.` + ); } let functionName = name ?? "extract"; diff --git a/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts index 11faf5b050c3..c3a38a44643c 100644 --- a/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts @@ -6,9 +6,10 @@ import { test, expect } from "@jest/globals"; import { HumanMessage } from "@langchain/core/messages"; import { AgentExecutor, createToolCallingAgent } from "langchain/agents"; import { ChatPromptTemplate } from "@langchain/core/prompts"; +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js"; import { TavilySearchResults } from "../../tools/tavily_search.js"; -import { z } from "zod"; void testChatModel( "Test Bedrock chat model Generating search queries: Command-r", @@ -386,10 +387,12 @@ test.skip.each([ }); test.skip("withStructuredOutput", async () => { - const weatherTool = z.object({ - city: z.string().describe("The city to get the weather for"), - state: z.string().describe("The state to get the weather for").optional(), - }).describe("Get the weather for a city"); + const weatherTool = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); const model = new BedrockChatWeb({ region: process.env.BEDROCK_AWS_REGION, model: "anthropic.claude-3-sonnet-20240229-v1:0", @@ -402,6 +405,43 @@ test.skip("withStructuredOutput", async () => { const modelWithTools = model.withStructuredOutput(weatherTool, { name: "weather", }); - const response = await modelWithTools.invoke("Whats the weather like in san francisco?"); + const response = await modelWithTools.invoke( + "Whats the weather like in san francisco?" + ); expect(response.city.toLowerCase()).toBe("san francisco"); -}) \ No newline at end of file +}); + +test.skip(".bind tools", async () => { + const weatherTool = z + .object({ + city: z.string().describe("The city to get the weather for"), + state: z.string().describe("The state to get the weather for").optional(), + }) + .describe("Get the weather for a city"); + const model = new BedrockChatWeb({ + region: process.env.BEDROCK_AWS_REGION, + model: "anthropic.claude-3-sonnet-20240229-v1:0", + maxRetries: 0, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, + }); + const modelWithTools = model.bind({ + tools: [ + { + name: "weather_tool", + description: weatherTool.description, + input_schema: zodToJsonSchema(weatherTool), + }, + ], + }); + const response = await modelWithTools.invoke( + "Whats the weather like in san francisco?" + ); + console.log(response); + if (!response.tool_calls?.[0]) { + throw new Error("No tool calls found in response"); + } + expect(response.tool_calls[0].args.city.toLowerCase()).toBe("san francisco"); +}); diff --git a/libs/langchain-community/src/output_parsers/bedrock.ts b/libs/langchain-community/src/output_parsers/bedrock.ts index 7ab79a4e4314..2432295971a1 100644 --- a/libs/langchain-community/src/output_parsers/bedrock.ts +++ b/libs/langchain-community/src/output_parsers/bedrock.ts @@ -63,7 +63,9 @@ export class BedrockChatToolsOutputParser< if (!message.tool_calls || message.tool_calls.length === 0) { return []; } - const tool = message.tool_calls.find((tool) => tool.name === this.keyName); + const tool = message.tool_calls.find( + (tool) => tool.name === this.keyName + ); return tool; }); if (tools[0] === undefined) {