Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: Allow runnable tools to take single string/ToolCall inputs #6096

Merged
merged 14 commits into from
Jul 17, 2024
1 change: 1 addition & 0 deletions examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"@langchain/google-vertexai": "workspace:*",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! 👋 I noticed that this PR adds a new peer dependency for "@langchain/langgraph". This comment is just to flag this change for the maintainers to review. Keep up the great work! 🚀

"@langchain/google-vertexai-web": "workspace:*",
"@langchain/groq": "workspace:*",
"@langchain/langgraph": "^0.0.28",
"@langchain/mistralai": "workspace:*",
"@langchain/mongodb": "workspace:*",
"@langchain/nomic": "workspace:*",
Expand Down
47 changes: 42 additions & 5 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ import {
isIterableIterator,
isIterator,
} from "./iter.js";
import { ToolInputParsingException } from "../tools/tool_exception.js";
import { _isToolCall } from "../tools/utils.js";
import { ToolCall } from "../messages/tool.js";

export { type RunnableInterface, RunnableBatchOptions };

Expand Down Expand Up @@ -1095,7 +1098,7 @@ export abstract class Runnable<
name?: string;
description?: string;
schema: z.ZodType<T>;
}): RunnableToolLike<z.ZodType<T>, RunOutput> {
}): RunnableToolLike<z.ZodType<T | ToolCall>, RunOutput> {
return convertRunnableToTool<T, RunOutput>(this, fields);
}
}
Expand Down Expand Up @@ -2828,8 +2831,29 @@ export class RunnableToolLike<
schema: RunInput;

constructor(fields: RunnableToolLikeArgs<RunInput, RunOutput>) {
const sequence = RunnableSequence.from([
RunnableLambda.from(async (input) => {
let toolInput: z.TypeOf<RunInput>;

if (_isToolCall(input)) {
try {
toolInput = await this.schema.parseAsync(input.args);
} catch (e) {
throw new ToolInputParsingException(
`Received tool input did not match expected schema`,
JSON.stringify(input.args)
);
}
} else {
toolInput = input;
}
return toolInput;
}).withConfig({ runName: `${fields.name}:parse_input` }),
fields.bound,
]).withConfig({ runName: fields.name });

super({
bound: fields.bound,
bound: sequence,
config: fields.config ?? {},
});

Expand Down Expand Up @@ -2863,11 +2887,24 @@ export function convertRunnableToTool<RunInput, RunOutput>(
description?: string;
schema: z.ZodType<RunInput>;
}
): RunnableToolLike<z.ZodType<RunInput>, RunOutput> {
): RunnableToolLike<z.ZodType<RunInput | ToolCall>, RunOutput> {
const name = fields.name ?? runnable.getName();
const description = fields.description ?? fields.schema.description;
const description = fields.description ?? fields.schema?.description;

if (fields.schema.constructor === z.ZodString) {
return new RunnableToolLike<z.ZodType<RunInput | ToolCall>, RunOutput>({
name,
description,
schema: z
.object({
input: z.string(),
})
.transform((input) => input.input) as z.ZodType,
bound: runnable,
});
}

return new RunnableToolLike<z.ZodType<RunInput>, RunOutput>({
return new RunnableToolLike<z.ZodType<RunInput | ToolCall>, RunOutput>({
name,
description,
schema: fields.schema,
Expand Down
42 changes: 42 additions & 0 deletions langchain-core/src/runnables/tests/runnable_tools.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { z } from "zod";
import { RunnableLambda, RunnableToolLike } from "../base.js";
import { FakeRetriever } from "../../utils/testing/index.js";
import { Document } from "../../documents/document.js";

test("Runnable asTool works", async () => {
const schema = z.object({
Expand Down Expand Up @@ -137,3 +139,43 @@ test("Runnable asTool uses Zod schema description if not provided", async () =>

expect(tool.description).toBe(description);
});

test("Runnable asTool can accept a string zod schema", async () => {
const lambda = RunnableLambda.from<string, string>((input) => {
return `${input}a`;
}).asTool({
name: "string_tool",
description: "A tool that appends 'a' to the input string",
schema: z.string(),
});

const result = await lambda.invoke("b");
expect(result).toBe("ba");
});

test("Runnables which dont accept ToolCalls as inputs can accept ToolCalls", async () => {
const pageContent = "Dogs are pretty cool, man!";
const retriever = new FakeRetriever({
output: [
new Document({
pageContent,
}),
],
});
const tool = retriever.asTool({
name: "pet_info_retriever",
description: "Get information about pets.",
schema: z.string(),
});

const result = await tool.invoke({
type: "tool_call",
name: "pet_info_retriever",
args: {
input: "dogs",
},
id: "string",
});
expect(result).toHaveLength(1);
expect(result[0].pageContent).toBe(pageContent);
});
99 changes: 50 additions & 49 deletions langchain-core/src/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ import {
} from "../runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "../runnables/base.js";
import { ToolCall, ToolMessage } from "../messages/tool.js";
import { ZodAny } from "../types/zod.js";
import { ZodObjectAny } from "../types/zod.js";
import { MessageContent } from "../messages/base.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { ToolInputParsingException } from "./tool_exception.js";
import { _isToolCall } from "./utils.js";

export { ToolInputParsingException };

export type ResponseFormat = "content" | "content_and_artifact" | string;

Expand All @@ -44,21 +48,7 @@ export interface ToolParams extends BaseLangChainParams {
responseFormat?: ResponseFormat;
}

/**
* Custom error class used to handle exceptions related to tool input parsing.
* It extends the built-in `Error` class and adds an optional `output`
* property that can hold the output that caused the exception.
*/
export class ToolInputParsingException extends Error {
output?: string;

constructor(message: string, output?: string) {
super(message);
this.output = output;
}
}

export interface StructuredToolInterface<T extends ZodAny = ZodAny>
export interface StructuredToolInterface<T extends ZodObjectAny = ZodObjectAny>
extends RunnableInterface<
(z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
ToolReturnType
Expand Down Expand Up @@ -96,7 +86,7 @@ export interface StructuredToolInterface<T extends ZodAny = ZodAny>
* Base class for Tools that accept input of any shape defined by a Zod schema.
*/
export abstract class StructuredTool<
T extends ZodAny = ZodAny
T extends ZodObjectAny = ZodObjectAny
> extends BaseLangChain<
(z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
ToolReturnType
Expand Down Expand Up @@ -259,7 +249,7 @@ export abstract class StructuredTool<
}
}

export interface ToolInterface<T extends ZodAny = ZodAny>
export interface ToolInterface<T extends ZodObjectAny = ZodObjectAny>
extends StructuredToolInterface<T> {
/**
* @deprecated Use .invoke() instead. Will be removed in 0.3.0.
Expand All @@ -279,7 +269,7 @@ export interface ToolInterface<T extends ZodAny = ZodAny>
/**
* Base class for Tools that accept input as a string.
*/
export abstract class Tool extends StructuredTool<ZodAny> {
export abstract class Tool extends StructuredTool<ZodObjectAny> {
schema = z
.object({ input: z.string().optional() })
.transform((obj) => obj.input);
Expand Down Expand Up @@ -328,8 +318,9 @@ export interface DynamicToolInput extends BaseDynamicToolInput {
/**
* Interface for the input parameters of the DynamicStructuredTool class.
*/
export interface DynamicStructuredToolInput<T extends ZodAny = ZodAny>
extends BaseDynamicToolInput {
export interface DynamicStructuredToolInput<
T extends ZodObjectAny = ZodObjectAny
> extends BaseDynamicToolInput {
func: (
input: BaseDynamicToolInput["responseFormat"] extends "content_and_artifact"
? ToolCall
Expand Down Expand Up @@ -393,7 +384,7 @@ export class DynamicTool extends Tool {
* provided function when the tool is called.
*/
export class DynamicStructuredTool<
T extends ZodAny = ZodAny
T extends ZodObjectAny = ZodObjectAny
> extends StructuredTool<T> {
static lc_name() {
return "DynamicStructuredTool";
Expand Down Expand Up @@ -456,11 +447,11 @@ export abstract class BaseToolkit {

/**
* Parameters for the tool function.
* @template {ZodAny} RunInput The input schema for the tool.
* @template {any} RunOutput The output type for the tool.
* @template {ZodObjectAny | z.ZodString = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, or a Zod string.
*/
interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
extends ToolParams {
interface ToolWrapperParams<
RunInput extends ZodObjectAny | z.ZodString = ZodObjectAny
> extends ToolParams {
/**
* The name of the tool. If using with an LLM, this
* will be passed as the tool name.
Expand Down Expand Up @@ -491,33 +482,52 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>

/**
* Creates a new StructuredTool instance with the provided function, name, description, and schema.
*
* @function
* @template {RunInput extends ZodAny = ZodAny} RunInput The input schema for the tool. This corresponds to the input type when the tool is invoked.
* @template {RunOutput = any} RunOutput The output type for the tool. This corresponds to the output type when the tool is invoked.
* @template {FuncInput extends z.infer<RunInput> | ToolCall = z.infer<RunInput>} FuncInput The input type for the function.
* @template {ZodObjectAny | z.ZodString = ZodObjectAny} T The input schema for the tool. Either any Zod object, or a Zod string.
*
* @param {RunnableFunc<z.infer<RunInput> | ToolCall, RunOutput>} func - The function to invoke when the tool is called.
* @param fields - An object containing the following properties:
* @param {RunnableFunc<z.output<T>, ToolReturnType>} func - The function to invoke when the tool is called.
* @param {ToolWrapperParams<T>} fields - An object containing the following properties:
* @param {string} fields.name The name of the tool.
* @param {string | undefined} fields.description The description of the tool. Defaults to either the description on the Zod schema, or `${fields.name} tool`.
* @param {z.ZodObject<any, any, any, any>} fields.schema The Zod schema defining the input for the tool.
* @param {ZodObjectAny | z.ZodString | undefined} fields.schema The Zod schema defining the input for the tool. If undefined, it will default to a Zod string schema.
*
* @returns {DynamicStructuredTool<RunInput, RunOutput>} A new StructuredTool instance.
* @returns {DynamicStructuredTool<T>} A new StructuredTool instance.
*/
export function tool<T extends ZodAny = ZodAny>(
export function tool<T extends z.ZodString = z.ZodString>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T> {
const schema =
fields.schema ??
z.object({ input: z.string().optional() }).transform((obj) => obj.input);
): DynamicTool;

export function tool<T extends ZodObjectAny = ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<T extends ZodObjectAny = ZodObjectAny>(
bracesproul marked this conversation as resolved.
Show resolved Hide resolved
func: RunnableFunc<z.output<T>, ToolReturnType>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T> | DynamicTool {
// If the schema is not provided, or it's a string schema, create a DynamicTool
if (!fields.schema || !fields.schema.shape) {
return new DynamicTool({
name: fields.name,
description:
fields.description ??
fields.schema?.description ??
`${fields.name} tool`,
responseFormat: fields.responseFormat,
func,
});
}

const description =
fields.description ?? schema.description ?? `${fields.name} tool`;
fields.description ?? fields.schema.description ?? `${fields.name} tool`;

return new DynamicStructuredTool({
name: fields.name,
description,
schema: schema as T,
schema: fields.schema,
bracesproul marked this conversation as resolved.
Show resolved Hide resolved
bracesproul marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Consider moving into DynamicStructuredTool constructor
func: async (input, runManager, config) => {
return new Promise((resolve, reject) => {
Expand All @@ -540,15 +550,6 @@ export function tool<T extends ZodAny = ZodAny>(
});
}

function _isToolCall(toolCall?: unknown): toolCall is ToolCall {
return !!(
toolCall &&
typeof toolCall === "object" &&
"type" in toolCall &&
toolCall.type === "tool_call"
);
}

function _formatToolOutput(params: {
content: unknown;
name: string;
Expand Down
16 changes: 16 additions & 0 deletions langchain-core/src/tools/tests/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,19 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
});

test("Tool can accept single string input", async () => {
const stringTool = tool<z.ZodString>(
(input: string): string => {
return `${input}a`;
},
{
name: "string_tool",
description: "A tool that appends 'a' to the input string",
schema: z.string(),
}
);

const result = await stringTool.invoke("b");
expect(result).toBe("ba");
});
13 changes: 13 additions & 0 deletions langchain-core/src/tools/tool_exception.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/**
* Custom error class used to handle exceptions related to tool input parsing.
* It extends the built-in `Error` class and adds an optional `output`
* property that can hold the output that caused the exception.
*/
export class ToolInputParsingException extends Error {
bracesproul marked this conversation as resolved.
Show resolved Hide resolved
output?: string;

constructor(message: string, output?: string) {
super(message);
this.output = output;
}
}
10 changes: 10 additions & 0 deletions langchain-core/src/tools/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { ToolCall } from "../messages/tool.js";

export function _isToolCall(toolCall?: unknown): toolCall is ToolCall {
return !!(
toolCall &&
typeof toolCall === "object" &&
"type" in toolCall &&
toolCall.type === "tool_call"
);
}
2 changes: 1 addition & 1 deletion langchain-core/src/types/zod.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { z } from "zod";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ZodAny = z.ZodObject<any, any, any, any>;
export type ZodObjectAny = z.ZodObject<any, any, any, any>;
bracesproul marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading