Skip to content

Commit

Permalink
implement .bind, fix call options, add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 11, 2024
1 parent 03d0954 commit 864905e
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 38 deletions.
28 changes: 28 additions & 0 deletions docs/core_docs/docs/integrations/chat/bedrock.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,31 @@ Anthropic Claude-3 models hosted on Bedrock have multimodal capabilities and can
import BedrockMultimodalExample from "@examples/models/chat/integration_bedrock_multimodal.ts";

<CodeBlock language="typescript">{BedrockMultimodalExample}</CodeBlock>

### Tool calling

:::info
Not all Bedrock models support tool calling. Please refer to the [model documentation](https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html) for more information.
:::

The examples below demonstrate how to use tool calling, along with the `withStructuredOutput` method to easily compose structured output LLM calls.

import ToolCalling from "@examples/models/chat/integration_bedrock_tools.ts";

<CodeBlock language="typescript">{ToolCalling}</CodeBlock>

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/003a684d-90eb-406e-a146-8ee5e617921b/r)
:::

#### `.withStructuredOutput({ ... })`

Using the `.withStructuredOutput` method, you can easily make the LLM return structured output, given only a Zod or JSON schema:

import WSOExample from "@examples/models/chat/integration_bedrock_wso.ts";

<CodeBlock language="typescript">{WSOExample}</CodeBlock>

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/1f7b1ad8-e4ac-4965-8ce1-fae06005f3d7/r)
:::
64 changes: 64 additions & 0 deletions examples/src/models/chat/integration_bedrock_tools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@

import { BedrockChat } from "@langchain/community/chat_models/bedrock";
// Or, from web environments:
// import { BedrockChat } from "@langchain/community/chat_models/bedrock/web";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

const model = new BedrockChat({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});

const weatherSchema = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");

const modelWithTools = model.bindTools([
{
name: "weather_tool",
description: weatherSchema.description,
input_schema: zodToJsonSchema(weatherSchema),
},
]);
// Optionally, you can bind tools via the `.bind` method:
// const modelWithTools = model.bind({
// tools: [
// {
// name: "weather_tool",
// description: weatherSchema.description,
// input_schema: zodToJsonSchema(weatherSchema),
// },
// ],
// });

const res = await modelWithTools.invoke("What's the weather in New York?");
console.log(res);

/*
AIMessage {
additional_kwargs: { id: 'msg_bdrk_01JF7hb4PNQPywP4gnBbgpHi' },
response_metadata: {
stop_reason: 'tool_use',
usage: { input_tokens: 300, output_tokens: 85 }
},
tool_calls: [
{
name: 'weather_tool',
args: {
city: 'New York',
state: 'NY'
},
id: 'toolu_bdrk_01AtEZRTCKioFXqhoNcpgaV7'
}
],
}
*/
35 changes: 35 additions & 0 deletions examples/src/models/chat/integration_bedrock_wso.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

import { BedrockChat } from "@langchain/community/chat_models/bedrock";
// Or, from web environments:
// import { BedrockChat } from "@langchain/community/chat_models/bedrock/web";
import { z } from "zod";

const model = new BedrockChat({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});

const weatherSchema = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");

const modelWithStructuredOutput = model.withStructuredOutput(weatherSchema, {
name: "weather_tool", // Optional, defaults to 'extract'
});

const res = await modelWithStructuredOutput.invoke(
"What's the weather in New York?"
);
console.log(res);

/*
{ city: 'New York', state: 'NY' }
*/
90 changes: 60 additions & 30 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,29 @@ export function convertMessagesToPrompt(
throw new Error(`Provider ${provider} does not support chat.`);
}

function formatTools(
tools: (StructuredToolInterface | AnthropicTool)[]
): AnthropicTool[] {
return tools.map((tool) => {
if (isStructuredTool(tool)) {
return {
name: tool.name,
description: tool.description,
input_schema: zodToJsonSchema(tool.schema),
};
}
return tool;
});
}

export interface BedrockChatCallOptions extends BaseChatModelCallOptions {
tools?: (StructuredToolInterface | AnthropicTool)[];
}

export interface BedrockChatFields
extends Partial<BaseBedrockInput>,
BaseChatModelParams {}

/**
* A type of Large Language Model (LLM) that interacts with the Bedrock
* service. It extends the base `LLM` class and implements the
Expand Down Expand Up @@ -208,7 +231,10 @@ export function convertMessagesToPrompt(
* runStreaming().catch(console.error);
* ```
*/
export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
export class BedrockChat
extends BaseChatModel<BedrockChatCallOptions, AIMessageChunk>
implements BaseBedrockInput
{
model = "amazon.titan-tg1-large";

region: string;
Expand Down Expand Up @@ -281,7 +307,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
return "BedrockChat";
}

constructor(fields?: Partial<BaseBedrockInput> & BaseChatModelParams) {
constructor(fields?: BedrockChatFields) {
super(fields ?? {});

this.model = fields?.model ?? this.model;
Expand Down Expand Up @@ -331,11 +357,14 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}

override invocationParams(options?: this["ParsedCallOptions"]) {
const callOptionTools = formatTools(options?.tools ?? []);
return {
tools: this._anthropicTools,
tools: [...(this._anthropicTools ?? []), ...callOptionTools],
temperature: this.temperature,
max_tokens: this.maxTokens,
stop: options?.stop,
stop: options?.stop ?? this.stopSequences,
modelKwargs: this.modelKwargs,
guardrailConfig: this.guardrailConfig,
};
}

Expand All @@ -353,7 +382,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

async _generate(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
options: Partial<this["ParsedCallOptions"]>,
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (this.streaming) {
Expand Down Expand Up @@ -381,7 +410,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

async _generateNonStreaming(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
options: Partial<this["ParsedCallOptions"]>,
_runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const service = "bedrock-runtime";
Expand Down Expand Up @@ -425,26 +454,34 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}
) {
const { bedrockMethod, endpointHost, provider } = fields;
const {
max_tokens,
temperature,
stop,
modelKwargs,
guardrailConfig,
tools,
} = this.invocationParams(options);
const inputBody = this.usesMessagesApi
? BedrockLLMInputOutputAdapter.prepareMessagesInput(
provider,
messages,
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
this.guardrailConfig,
this._anthropicTools
max_tokens,
temperature,
stop,
modelKwargs,
guardrailConfig,
tools
)
: BedrockLLMInputOutputAdapter.prepareInput(
provider,
convertMessagesToPromptAnthropic(messages),
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
max_tokens,
temperature,
stop,
modelKwargs,
fields.bedrockMethod,
this.guardrailConfig
guardrailConfig
);

const url = new URL(
Expand Down Expand Up @@ -694,28 +731,19 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

override bindTools(
tools: (StructuredToolInterface | AnthropicTool)[],
_kwargs?: Partial<BaseChatModelCallOptions>
_kwargs?: Partial<this["ParsedCallOptions"]>
): Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
BaseChatModelCallOptions
this["ParsedCallOptions"]
> {
const provider = this.model.split(".")[0];
if (provider !== "anthropic") {
throw new Error(
"Currently, tool calling through Bedrock is only supported for Anthropic models."
);
}
this._anthropicTools = tools.map((tool) => {
if (isStructuredTool(tool)) {
return {
name: tool.name,
description: tool.description,
input_schema: zodToJsonSchema(tool.schema),
};
}
return tool;
});
this._anthropicTools = formatTools(tools);
return this;
}

Expand Down Expand Up @@ -762,7 +790,9 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
const method = config?.method;
const includeRaw = config?.includeRaw;
if (method === "jsonMode") {
throw new Error(`Anthropic only supports "functionCalling" as a method.`);
throw new Error(
`BedrockChat only supports "functionCalling" as a method.`
);
}

let functionName = name ?? "extract";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import { test, expect } from "@jest/globals";
import { HumanMessage } from "@langchain/core/messages";
import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js";
import { TavilySearchResults } from "../../tools/tavily_search.js";
import { z } from "zod";

void testChatModel(
"Test Bedrock chat model Generating search queries: Command-r",
Expand Down Expand Up @@ -386,10 +387,12 @@ test.skip.each([
});

test.skip("withStructuredOutput", async () => {
const weatherTool = z.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
}).describe("Get the weather for a city");
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
Expand All @@ -402,6 +405,43 @@ test.skip("withStructuredOutput", async () => {
const modelWithTools = model.withStructuredOutput(weatherTool, {
name: "weather",
});
const response = await modelWithTools.invoke("Whats the weather like in san francisco?");
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
expect(response.city.toLowerCase()).toBe("san francisco");
})
});

test.skip(".bind tools", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.bind({
tools: [
{
name: "weather_tool",
description: weatherTool.description,
input_schema: zodToJsonSchema(weatherTool),
},
],
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
console.log(response);
if (!response.tool_calls?.[0]) {
throw new Error("No tool calls found in response");
}
expect(response.tool_calls[0].args.city.toLowerCase()).toBe("san francisco");
});
4 changes: 3 additions & 1 deletion libs/langchain-community/src/output_parsers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ export class BedrockChatToolsOutputParser<
if (!message.tool_calls || message.tool_calls.length === 0) {
return [];
}
const tool = message.tool_calls.find((tool) => tool.name === this.keyName);
const tool = message.tool_calls.find(
(tool) => tool.name === this.keyName
);
return tool;
});
if (tools[0] === undefined) {
Expand Down

0 comments on commit 864905e

Please sign in to comment.