From 1c4b2d83cacea6e4bd0b537356cc3bc0970ef848 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Wed, 17 Jul 2024 11:03:12 -0700 Subject: [PATCH] core[patch]: Allow runnable tools to take single string/`ToolCall` inputs (#6096) * core[patch]: Allow runnable tools to take single string inputs * add test for tool func * chore: lint files * cr * cr * cr * fix types * rename ZodAny to ZodObjectAny * docstring nits * fiox * cr * cr --- examples/package.json | 1 + langchain-core/src/runnables/base.ts | 46 +++++++- .../runnables/tests/runnable_tools.test.ts | 42 ++++++++ langchain-core/src/tools/index.ts | 102 +++++++++--------- langchain-core/src/tools/tests/tools.test.ts | 16 +++ langchain-core/src/tools/utils.ts | 24 +++++ langchain-core/src/types/zod.ts | 2 +- yarn.lock | 17 +++ 8 files changed, 194 insertions(+), 56 deletions(-) create mode 100644 langchain-core/src/tools/utils.ts diff --git a/examples/package.json b/examples/package.json index 56fb064ed456..8f81162219bd 100644 --- a/examples/package.json +++ b/examples/package.json @@ -48,6 +48,7 @@ "@langchain/google-vertexai": "workspace:*", "@langchain/google-vertexai-web": "workspace:*", "@langchain/groq": "workspace:*", + "@langchain/langgraph": "^0.0.28", "@langchain/mistralai": "workspace:*", "@langchain/mongodb": "workspace:*", "@langchain/nomic": "workspace:*", diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index ac099057ffc3..ea0f00bd402b 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -52,6 +52,8 @@ import { isIterableIterator, isIterator, } from "./iter.js"; +import { _isToolCall, ToolInputParsingException } from "../tools/utils.js"; +import { ToolCall } from "../messages/tool.js"; export { type RunnableInterface, RunnableBatchOptions }; @@ -1095,7 +1097,7 @@ export abstract class Runnable< name?: string; description?: string; schema: z.ZodType; - }): RunnableToolLike, RunOutput> { + }): RunnableToolLike, RunOutput> { return convertRunnableToTool(this, fields); } } @@ -2828,8 +2830,29 @@ export class RunnableToolLike< schema: RunInput; constructor(fields: RunnableToolLikeArgs) { + const sequence = RunnableSequence.from([ + RunnableLambda.from(async (input) => { + let toolInput: z.TypeOf; + + if (_isToolCall(input)) { + try { + toolInput = await this.schema.parseAsync(input.args); + } catch (e) { + throw new ToolInputParsingException( + `Received tool input did not match expected schema`, + JSON.stringify(input.args) + ); + } + } else { + toolInput = input; + } + return toolInput; + }).withConfig({ runName: `${fields.name}:parse_input` }), + fields.bound, + ]).withConfig({ runName: fields.name }); + super({ - bound: fields.bound, + bound: sequence, config: fields.config ?? {}, }); @@ -2863,11 +2886,24 @@ export function convertRunnableToTool( description?: string; schema: z.ZodType; } -): RunnableToolLike, RunOutput> { +): RunnableToolLike, RunOutput> { const name = fields.name ?? runnable.getName(); - const description = fields.description ?? fields.schema.description; + const description = fields.description ?? fields.schema?.description; + + if (fields.schema.constructor === z.ZodString) { + return new RunnableToolLike, RunOutput>({ + name, + description, + schema: z + .object({ + input: z.string(), + }) + .transform((input) => input.input) as z.ZodType, + bound: runnable, + }); + } - return new RunnableToolLike, RunOutput>({ + return new RunnableToolLike, RunOutput>({ name, description, schema: fields.schema, diff --git a/langchain-core/src/runnables/tests/runnable_tools.test.ts b/langchain-core/src/runnables/tests/runnable_tools.test.ts index 4c16aea077c7..b8ac95cf4940 100644 --- a/langchain-core/src/runnables/tests/runnable_tools.test.ts +++ b/langchain-core/src/runnables/tests/runnable_tools.test.ts @@ -1,5 +1,7 @@ import { z } from "zod"; import { RunnableLambda, RunnableToolLike } from "../base.js"; +import { FakeRetriever } from "../../utils/testing/index.js"; +import { Document } from "../../documents/document.js"; test("Runnable asTool works", async () => { const schema = z.object({ @@ -137,3 +139,43 @@ test("Runnable asTool uses Zod schema description if not provided", async () => expect(tool.description).toBe(description); }); + +test("Runnable asTool can accept a string zod schema", async () => { + const lambda = RunnableLambda.from((input) => { + return `${input}a`; + }).asTool({ + name: "string_tool", + description: "A tool that appends 'a' to the input string", + schema: z.string(), + }); + + const result = await lambda.invoke("b"); + expect(result).toBe("ba"); +}); + +test("Runnables which dont accept ToolCalls as inputs can accept ToolCalls", async () => { + const pageContent = "Dogs are pretty cool, man!"; + const retriever = new FakeRetriever({ + output: [ + new Document({ + pageContent, + }), + ], + }); + const tool = retriever.asTool({ + name: "pet_info_retriever", + description: "Get information about pets.", + schema: z.string(), + }); + + const result = await tool.invoke({ + type: "tool_call", + name: "pet_info_retriever", + args: { + input: "dogs", + }, + id: "string", + }); + expect(result).toHaveLength(1); + expect(result[0].pageContent).toBe(pageContent); +}); diff --git a/langchain-core/src/tools/index.ts b/langchain-core/src/tools/index.ts index 09a36c41aff9..8286bb7cc33d 100644 --- a/langchain-core/src/tools/index.ts +++ b/langchain-core/src/tools/index.ts @@ -16,9 +16,12 @@ import { } from "../runnables/config.js"; import type { RunnableFunc, RunnableInterface } from "../runnables/base.js"; import { ToolCall, ToolMessage } from "../messages/tool.js"; -import { ZodAny } from "../types/zod.js"; +import { ZodObjectAny } from "../types/zod.js"; import { MessageContent } from "../messages/base.js"; import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js"; +import { _isToolCall, ToolInputParsingException } from "./utils.js"; + +export { ToolInputParsingException }; export type ResponseFormat = "content" | "content_and_artifact" | string; @@ -44,21 +47,7 @@ export interface ToolParams extends BaseLangChainParams { responseFormat?: ResponseFormat; } -/** - * Custom error class used to handle exceptions related to tool input parsing. - * It extends the built-in `Error` class and adds an optional `output` - * property that can hold the output that caused the exception. - */ -export class ToolInputParsingException extends Error { - output?: string; - - constructor(message: string, output?: string) { - super(message); - this.output = output; - } -} - -export interface StructuredToolInterface +export interface StructuredToolInterface extends RunnableInterface< (z.output extends string ? string : never) | z.input | ToolCall, ToolReturnType @@ -96,7 +85,7 @@ export interface StructuredToolInterface * Base class for Tools that accept input of any shape defined by a Zod schema. */ export abstract class StructuredTool< - T extends ZodAny = ZodAny + T extends ZodObjectAny = ZodObjectAny > extends BaseLangChain< (z.output extends string ? string : never) | z.input | ToolCall, ToolReturnType @@ -259,7 +248,7 @@ export abstract class StructuredTool< } } -export interface ToolInterface +export interface ToolInterface extends StructuredToolInterface { /** * @deprecated Use .invoke() instead. Will be removed in 0.3.0. @@ -279,7 +268,7 @@ export interface ToolInterface /** * Base class for Tools that accept input as a string. */ -export abstract class Tool extends StructuredTool { +export abstract class Tool extends StructuredTool { schema = z .object({ input: z.string().optional() }) .transform((obj) => obj.input); @@ -328,8 +317,9 @@ export interface DynamicToolInput extends BaseDynamicToolInput { /** * Interface for the input parameters of the DynamicStructuredTool class. */ -export interface DynamicStructuredToolInput - extends BaseDynamicToolInput { +export interface DynamicStructuredToolInput< + T extends ZodObjectAny = ZodObjectAny +> extends BaseDynamicToolInput { func: ( input: BaseDynamicToolInput["responseFormat"] extends "content_and_artifact" ? ToolCall @@ -393,7 +383,7 @@ export class DynamicTool extends Tool { * provided function when the tool is called. */ export class DynamicStructuredTool< - T extends ZodAny = ZodAny + T extends ZodObjectAny = ZodObjectAny > extends StructuredTool { static lc_name() { return "DynamicStructuredTool"; @@ -456,11 +446,11 @@ export abstract class BaseToolkit { /** * Parameters for the tool function. - * @template {ZodAny} RunInput The input schema for the tool. - * @template {any} RunOutput The output type for the tool. + * @template {ZodObjectAny | z.ZodString = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, or a Zod string. */ -interface ToolWrapperParams - extends ToolParams { +interface ToolWrapperParams< + RunInput extends ZodObjectAny | z.ZodString = ZodObjectAny +> extends ToolParams { /** * The name of the tool. If using with an LLM, this * will be passed as the tool name. @@ -491,33 +481,54 @@ interface ToolWrapperParams /** * Creates a new StructuredTool instance with the provided function, name, description, and schema. + * * @function - * @template {RunInput extends ZodAny = ZodAny} RunInput The input schema for the tool. This corresponds to the input type when the tool is invoked. - * @template {RunOutput = any} RunOutput The output type for the tool. This corresponds to the output type when the tool is invoked. - * @template {FuncInput extends z.infer | ToolCall = z.infer} FuncInput The input type for the function. + * @template {ZodObjectAny | z.ZodString = ZodObjectAny} T The input schema for the tool. Either any Zod object, or a Zod string. * - * @param {RunnableFunc | ToolCall, RunOutput>} func - The function to invoke when the tool is called. - * @param fields - An object containing the following properties: + * @param {RunnableFunc, ToolReturnType>} func - The function to invoke when the tool is called. + * @param {ToolWrapperParams} fields - An object containing the following properties: * @param {string} fields.name The name of the tool. * @param {string | undefined} fields.description The description of the tool. Defaults to either the description on the Zod schema, or `${fields.name} tool`. - * @param {z.ZodObject} fields.schema The Zod schema defining the input for the tool. + * @param {ZodObjectAny | z.ZodString | undefined} fields.schema The Zod schema defining the input for the tool. If undefined, it will default to a Zod string schema. * - * @returns {DynamicStructuredTool} A new StructuredTool instance. + * @returns {DynamicStructuredTool} A new StructuredTool instance. */ -export function tool( +export function tool( func: RunnableFunc, ToolReturnType>, fields: ToolWrapperParams -): DynamicStructuredTool { - const schema = - fields.schema ?? - z.object({ input: z.string().optional() }).transform((obj) => obj.input); +): DynamicTool; + +export function tool( + func: RunnableFunc, ToolReturnType>, + fields: ToolWrapperParams +): DynamicStructuredTool; + +export function tool( + func: RunnableFunc, ToolReturnType>, + fields: ToolWrapperParams +): + | DynamicStructuredTool + | DynamicTool { + // If the schema is not provided, or it's a string schema, create a DynamicTool + if (!fields.schema || !("shape" in fields.schema) || !fields.schema.shape) { + return new DynamicTool({ + name: fields.name, + description: + fields.description ?? + fields.schema?.description ?? + `${fields.name} tool`, + responseFormat: fields.responseFormat, + func, + }); + } const description = - fields.description ?? schema.description ?? `${fields.name} tool`; - return new DynamicStructuredTool({ + fields.description ?? fields.schema.description ?? `${fields.name} tool`; + + return new DynamicStructuredTool({ name: fields.name, description, - schema: schema as T, + schema: fields.schema as T extends ZodObjectAny ? T : ZodObjectAny, // TODO: Consider moving into DynamicStructuredTool constructor func: async (input, runManager, config) => { return new Promise((resolve, reject) => { @@ -540,15 +551,6 @@ export function tool( }); } -function _isToolCall(toolCall?: unknown): toolCall is ToolCall { - return !!( - toolCall && - typeof toolCall === "object" && - "type" in toolCall && - toolCall.type === "tool_call" - ); -} - function _formatToolOutput(params: { content: unknown; name: string; diff --git a/langchain-core/src/tools/tests/tools.test.ts b/langchain-core/src/tools/tests/tools.test.ts index b514b99c1827..bf577a4a1dc9 100644 --- a/langchain-core/src/tools/tests/tools.test.ts +++ b/langchain-core/src/tools/tests/tools.test.ts @@ -99,3 +99,19 @@ test("Returns tool message if responseFormat is content_and_artifact and returns expect(toolResult.artifact).toEqual({ location: "San Francisco" }); expect(toolResult.name).toBe("weather"); }); + +test("Tool can accept single string input", async () => { + const stringTool = tool( + (input: string): string => { + return `${input}a`; + }, + { + name: "string_tool", + description: "A tool that appends 'a' to the input string", + schema: z.string(), + } + ); + + const result = await stringTool.invoke("b"); + expect(result).toBe("ba"); +}); diff --git a/langchain-core/src/tools/utils.ts b/langchain-core/src/tools/utils.ts new file mode 100644 index 000000000000..b9c5bd8b384d --- /dev/null +++ b/langchain-core/src/tools/utils.ts @@ -0,0 +1,24 @@ +import { ToolCall } from "../messages/tool.js"; + +export function _isToolCall(toolCall?: unknown): toolCall is ToolCall { + return !!( + toolCall && + typeof toolCall === "object" && + "type" in toolCall && + toolCall.type === "tool_call" + ); +} + +/** + * Custom error class used to handle exceptions related to tool input parsing. + * It extends the built-in `Error` class and adds an optional `output` + * property that can hold the output that caused the exception. + */ +export class ToolInputParsingException extends Error { + output?: string; + + constructor(message: string, output?: string) { + super(message); + this.output = output; + } +} diff --git a/langchain-core/src/types/zod.ts b/langchain-core/src/types/zod.ts index d864170ddafa..faaa92b3bff2 100644 --- a/langchain-core/src/types/zod.ts +++ b/langchain-core/src/types/zod.ts @@ -1,4 +1,4 @@ import type { z } from "zod"; // eslint-disable-next-line @typescript-eslint/no-explicit-any -export type ZodAny = z.ZodObject; +export type ZodObjectAny = z.ZodObject; diff --git a/yarn.lock b/yarn.lock index 2497a15f11f9..8fcc427d9b2e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11341,6 +11341,22 @@ __metadata: languageName: node linkType: hard +"@langchain/langgraph@npm:^0.0.28": + version: 0.0.28 + resolution: "@langchain/langgraph@npm:0.0.28" + dependencies: + "@langchain/core": ">=0.2.16 <0.3.0" + uuid: ^10.0.0 + zod: ^3.23.8 + peerDependencies: + better-sqlite3: ^9.5.0 + peerDependenciesMeta: + better-sqlite3: + optional: true + checksum: 1465791026ccd6eaa13a2f2d03b8fb9f0972a8c23b9da1cfd581074f413ea60ef860de6d704c6a3b49f7425f23d6ba49c23255167ae83ab7d70dc00cc0560ce2 + languageName: node + linkType: hard + "@langchain/mistralai@workspace:*, @langchain/mistralai@workspace:libs/langchain-mistralai": version: 0.0.0-use.local resolution: "@langchain/mistralai@workspace:libs/langchain-mistralai" @@ -24929,6 +24945,7 @@ __metadata: "@langchain/google-vertexai": "workspace:*" "@langchain/google-vertexai-web": "workspace:*" "@langchain/groq": "workspace:*" + "@langchain/langgraph": ^0.0.28 "@langchain/mistralai": "workspace:*" "@langchain/mongodb": "workspace:*" "@langchain/nomic": "workspace:*"