Skip to content

Commit

Permalink
feat(core): Preserve direct tool outputs, pass raw tool call into too…
Browse files Browse the repository at this point in the history
…ls if available (langchain-ai#7340)
  • Loading branch information
jacoblee93 authored and syntaxsec committed Dec 13, 2024
1 parent e40d130 commit d4af337
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 35 deletions.
24 changes: 23 additions & 1 deletion langchain-core/src/messages/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,30 @@ export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields {
status?: "success" | "error";
}

/**
* Marker parameter for objects that tools can return directly.
*
* If a custom BaseTool is invoked with a ToolCall and the output of custom code is
* not an instance of DirectToolOutput, the output will automatically be coerced to
* a string and wrapped in a ToolMessage.
*/
export interface DirectToolOutput {
readonly lc_direct_tool_output: boolean;
}

export function isDirectToolOutput(x: unknown): x is DirectToolOutput {
return (
x != null &&
typeof x === "object" &&
"lc_direct_tool_output" in x &&
x.lc_direct_tool_output === true
);
}

/**
* Represents a tool message in a conversation.
*/
export class ToolMessage extends BaseMessage {
export class ToolMessage extends BaseMessage implements DirectToolOutput {
static lc_name() {
return "ToolMessage";
}
Expand All @@ -40,6 +60,8 @@ export class ToolMessage extends BaseMessage {
return { tool_call_id: "tool_call_id" };
}

lc_direct_tool_output = true;

/**
* Status of the tool invocation.
* @version 0.2.19
Expand Down
53 changes: 32 additions & 21 deletions langchain-core/src/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
type RunnableConfig,
} from "../runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "../runnables/base.js";
import { ToolCall, ToolMessage } from "../messages/tool.js";
import { isDirectToolOutput, ToolCall, ToolMessage } from "../messages/tool.js";
import { MessageContent } from "../messages/base.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { _isToolCall, ToolInputParsingException } from "./utils.js";
Expand Down Expand Up @@ -57,6 +57,11 @@ export interface ToolParams extends BaseLangChainParams {
verboseParsingErrors?: boolean;
}

export type ToolRunnableConfig<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType extends Record<string, any> = Record<string, any>
> = RunnableConfig<ConfigurableFieldType> & { toolCall?: ToolCall };

/**
* Schema for defining tools.
*
Expand Down Expand Up @@ -159,7 +164,7 @@ export abstract class StructuredTool<
protected abstract _call(
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
parentConfig?: ToolRunnableConfig
): Promise<ToolReturnType>;

/**
Expand All @@ -182,21 +187,23 @@ export abstract class StructuredTool<
| ToolCall
| undefined;

let enrichedConfig: ToolRunnableConfig = ensureConfig(config);
if (_isToolCall(input)) {
tool_call_id = input.id;
toolInput = input.args;
enrichedConfig = {
...enrichedConfig,
toolCall: input,
configurable: {
...enrichedConfig.configurable,
tool_call_id,
},
};
} else {
toolInput = input;
}

const ensuredConfig = ensureConfig(config);
return this.call(toolInput, {
...ensuredConfig,
configurable: {
...ensuredConfig.configurable,
tool_call_id,
},
});
return this.call(toolInput, enrichedConfig);
}

/**
Expand All @@ -211,8 +218,8 @@ export abstract class StructuredTool<
* @returns A Promise that resolves with a string.
*/
async call(
arg: (z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
configArg?: Callbacks | RunnableConfig,
arg: (z.output<T> extends string ? string : never) | z.input<T>,
configArg?: Callbacks | ToolRunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<ToolReturnType> {
Expand All @@ -229,7 +236,7 @@ export abstract class StructuredTool<
}

const config = parseCallbackConfigArg(configArg);
const callbackManager_ = await CallbackManager.configure(
const callbackManager_ = CallbackManager.configure(
config.callbacks,
this.callbacks,
config.tags || tags,
Expand Down Expand Up @@ -350,7 +357,7 @@ export interface DynamicToolInput extends BaseDynamicToolInput {
func: (
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
config?: ToolRunnableConfig
) => Promise<ToolReturnType>;
}

Expand Down Expand Up @@ -400,7 +407,7 @@ export class DynamicTool extends Tool {
*/
async call(
arg: string | undefined | z.input<this["schema"]> | ToolCall,
configArg?: RunnableConfig | Callbacks
configArg?: ToolRunnableConfig | Callbacks
): Promise<ToolReturnType> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
Expand All @@ -413,7 +420,7 @@ export class DynamicTool extends Tool {
async _call(
input: string,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
parentConfig?: ToolRunnableConfig
): Promise<ToolReturnType> {
return this.func(input, runManager, parentConfig);
}
Expand Down Expand Up @@ -553,26 +560,30 @@ interface ToolWrapperParams<
* @returns {DynamicStructuredTool<T>} A new StructuredTool instance.
*/
export function tool<T extends z.ZodString>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
func: RunnableFunc<z.output<T>, ToolReturnType, ToolRunnableConfig>,
fields: ToolWrapperParams<T>
): DynamicTool;

export function tool<T extends ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
func: RunnableFunc<z.output<T>, ToolReturnType, ToolRunnableConfig>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function tool<T extends Record<string, any>>(
func: RunnableFunc<T, ToolReturnType>,
func: RunnableFunc<T, ToolReturnType, ToolRunnableConfig>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodObjectAny | z.ZodString | Record<string, any> = ZodObjectAny
>(
func: RunnableFunc<T extends ZodObjectAny ? z.output<T> : T, ToolReturnType>,
func: RunnableFunc<
T extends ZodObjectAny ? z.output<T> : T,
ToolReturnType,
ToolRunnableConfig
>,
fields: ToolWrapperParams<T>
):
| DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>
Expand Down Expand Up @@ -649,7 +660,7 @@ function _formatToolOutput(params: {
toolCallId?: string;
}): ToolReturnType {
const { content, artifact, toolCallId } = params;
if (toolCallId) {
if (toolCallId && !isDirectToolOutput(content)) {
if (
typeof content === "string" ||
(Array.isArray(content) &&
Expand Down
86 changes: 73 additions & 13 deletions langchain-core/src/tools/tests/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ import { z } from "zod";

import { DynamicStructuredTool, tool } from "../index.js";
import { ToolMessage } from "../../messages/tool.js";
import { RunnableConfig } from "../../runnables/types.js";

test("Tool should error if responseFormat is content_and_artifact but the function doesn't return a tuple", async () => {
const weatherSchema = z.object({
location: z.string(),
});

const weatherTool = tool(
(_) => {
// Should be able to type this as base RunnableConfig without issue,
// though true type is more specific
(_, _config: RunnableConfig) => {
return "str";
},
{
Expand Down Expand Up @@ -51,9 +54,15 @@ test("Does not return tool message if responseFormat is content_and_artifact and
const weatherSchema = z.object({
location: z.string(),
});
const toolCall = {
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
} as const;

const weatherTool = tool(
(input) => {
(input, config) => {
expect(config.toolCall).toEqual(toolCall);
return ["msg_content", input];
},
{
Expand All @@ -63,11 +72,7 @@ test("Does not return tool message if responseFormat is content_and_artifact and
}
);

const toolResult = await weatherTool.invoke({
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
});
const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBe("msg_content");
});
Expand All @@ -77,8 +82,16 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
location: z.string(),
});

const toolCall = {
id: "testid",
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
} as const;

const weatherTool = tool(
(input) => {
(input, config) => {
expect(config.toolCall).toEqual(toolCall);
return ["msg_content", input];
},
{
Expand All @@ -88,23 +101,63 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
}
);

const toolResult = await weatherTool.invoke({
const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBeInstanceOf(ToolMessage);
expect(toolResult.content).toBe("msg_content");
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
});

test("Does not double wrap a returned tool message even if a tool call with id is passed in", async () => {
const weatherSchema = z.object({
location: z.string(),
});

const toolCall = {
id: "testid",
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
});
} as const;

const weatherTool = tool(
(_, config) => {
expect(config.toolCall).toEqual(toolCall);
return new ToolMessage({
tool_call_id: "not_original",
content: "bar",
name: "baz",
});
},
{
name: "weather",
schema: weatherSchema,
}
);

const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBeInstanceOf(ToolMessage);
expect(toolResult.content).toBe("msg_content");
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
expect(toolResult.tool_call_id).toBe("not_original");
expect(toolResult.content).toBe("bar");
expect(toolResult.name).toBe("baz");
});

test("Tool can accept single string input", async () => {
const toolCall = {
id: "testid",
args: { input: "b" },
name: "string_tool",
type: "tool_call",
} as const;

const stringTool = tool<z.ZodString>(
(input: string, config): string => {
expect(config).toMatchObject({ configurable: { foo: "bar" } });
if (config.configurable.usesToolCall) {
expect(config.toolCall).toEqual(toolCall);
}
return `${input}a`;
},
{
Expand All @@ -116,6 +169,13 @@ test("Tool can accept single string input", async () => {

const result = await stringTool.invoke("b", { configurable: { foo: "bar" } });
expect(result).toBe("ba");

const result2 = await stringTool.invoke(toolCall, {
configurable: { foo: "bar", usesToolCall: true },
});
expect(result2).toBeInstanceOf(ToolMessage);
expect(result2.content).toBe("ba");
expect(result2.name).toBe("string_tool");
});

test("Tool declared with JSON schema", async () => {
Expand Down

0 comments on commit d4af337

Please sign in to comment.