diff --git a/langchain/src/chat_models/openai.ts b/langchain/src/chat_models/openai.ts
index 18b23e25e2a8..cbda733aea9c 100644
--- a/langchain/src/chat_models/openai.ts
+++ b/langchain/src/chat_models/openai.ts
@@ -13,9 +13,11 @@ import {
   FunctionMessageChunk,
   HumanMessageChunk,
   SystemMessageChunk,
+  ToolMessage,
+  ToolMessageChunk,
 } from "../schema/index.js";
 import { StructuredTool } from "../tools/base.js";
-import { formatToOpenAIFunction } from "../tools/convert_to_openai.js";
+import { formatToOpenAITool } from "../tools/convert_to_openai.js";
 import {
   AzureOpenAIInput,
   OpenAICallOptions,
@@ -60,7 +62,8 @@ function extractGenericMessageCustomRole(message: ChatMessage) {
     message.role !== "system" &&
     message.role !== "assistant" &&
     message.role !== "user" &&
-    message.role !== "function"
+    message.role !== "function" &&
+    message.role !== "tool"
   ) {
     console.warn(`Unknown message role: ${message.role}`);
   }
@@ -79,6 +82,8 @@ function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
       return "user";
     case "function":
       return "function";
+    case "tool":
+      return "tool";
     case "generic": {
       if (!ChatMessage.isInstance(message))
         throw new Error("Invalid generic chat message");
@@ -96,6 +101,7 @@ function openAIResponseToChatMessage(
     case "assistant":
       return new AIMessage(message.content || "", {
         function_call: message.function_call,
+        tool_calls: message.tool_calls,
       });
     default:
       return new ChatMessage(message.content || "", message.role ?? "unknown");
@@ -114,6 +120,10 @@ function _convertDeltaToMessageChunk(
     additional_kwargs = {
       function_call: delta.function_call,
     };
+  } else if (delta.tool_calls) {
+    additional_kwargs = {
+      tool_calls: delta.tool_calls,
+    };
   } else {
     additional_kwargs = {};
   }
@@ -129,15 +139,37 @@ function _convertDeltaToMessageChunk(
       additional_kwargs,
       name: delta.name,
     });
+  } else if (role === "tool") {
+    return new ToolMessageChunk({
+      content,
+      additional_kwargs,
+      tool_call_id: delta.tool_call_id,
+    });
   } else {
     return new ChatMessageChunk({ content, role });
   }
 }
 
+function convertMessagesToOpenAIParams(messages: BaseMessage[]) {
+  // TODO: Function messages do not support array content, fix cast
+  return messages.map(
+    (message) =>
+      ({
+        role: messageToOpenAIRole(message),
+        content: message.content,
+        name: message.name,
+        function_call: message.additional_kwargs.function_call,
+        tool_calls: message.additional_kwargs.tool_calls,
+        tool_call_id: (message as ToolMessage).tool_call_id,
+      } as OpenAICompletionParam)
+  );
+}
+
 export interface ChatOpenAICallOptions
   extends OpenAICallOptions,
     BaseFunctionCallOptions {
-  tools?: StructuredTool[];
+  tools?: StructuredTool[] | OpenAIClient.ChatCompletionTool[];
+  tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption;
   promptIndex?: number;
   response_format?: { type: "json_object" };
   seed?: number;
@@ -179,6 +211,7 @@ export class ChatOpenAI<
       "function_call",
       "functions",
       "tools",
+      "tool_choice",
       "promptIndex",
       "response_format",
       "seed",
@@ -343,7 +376,20 @@ export class ChatOpenAI<
   invocationParams(
     options?: this["ParsedCallOptions"]
   ): Omit<OpenAIClient.Chat.ChatCompletionCreateParams, "messages"> {
-    return {
+    function isStructuredToolArray(
+      tools?: unknown[]
+    ): tools is StructuredTool[] {
+      return (
+        tools !== undefined &&
+        tools.every((tool) =>
+          Array.isArray((tool as StructuredTool).lc_namespace)
+        )
+      );
+    }
+    const params: Omit<
+      OpenAIClient.Chat.ChatCompletionCreateParams,
+      "messages"
+    > = {
       model: this.modelName,
       temperature: this.temperature,
       top_p: this.topP,
@@ -355,16 +401,17 @@ export class ChatOpenAI<
       stop: options?.stop ?? this.stop,
       user: this.user,
       stream: this.streaming,
-      functions:
-        options?.functions ??
-        (options?.tools
-          ? options?.tools.map(formatToOpenAIFunction)
-          : undefined),
+      functions: options?.functions,
       function_call: options?.function_call,
+      tools: isStructuredToolArray(options?.tools)
+        ? options?.tools.map(formatToOpenAITool)
+        : options?.tools,
+      tool_choice: options?.tool_choice,
       response_format: options?.response_format,
       seed: options?.seed,
       ...this.modelKwargs,
     };
+    return params;
   }
 
   /** @ignore */
@@ -386,17 +433,8 @@ export class ChatOpenAI<
     options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
-    const messagesMapped: OpenAICompletionParam[] = messages.map(
-      // TODO: Function messages do not support array content, fix cast
-      (message) =>
-        ({
-          role: messageToOpenAIRole(message),
-          content: message.content,
-          name: message.name,
-          function_call: message.additional_kwargs
-            .function_call as OpenAIClient.Chat.ChatCompletionMessage.FunctionCall,
-        } as OpenAICompletionParam)
-    );
+    const messagesMapped: OpenAICompletionParam[] =
+      convertMessagesToOpenAIParams(messages);
     const params = {
       ...this.invocationParams(options),
       messages: messagesMapped,
@@ -419,7 +457,7 @@ export class ChatOpenAI<
         };
         if (typeof chunk.content !== "string") {
           console.log(
-            "[WARNING:] Received non-string content from OpenAI. This is currently not supported."
+            "[WARNING]: Received non-string content from OpenAI. This is currently not supported."
           );
           continue;
         }
@@ -461,17 +499,7 @@ export class ChatOpenAI<
     const tokenUsage: TokenUsage = {};
     const params = this.invocationParams(options);
     const messagesMapped: OpenAICompletionParam[] =
-      // TODO: Function messages do not support array content, fix cast
-      messages.map(
-        (message) =>
-          ({
-            role: messageToOpenAIRole(message),
-            content: message.content,
-            name: message.name,
-            function_call: message.additional_kwargs
-              .function_call as OpenAIClient.Chat.ChatCompletionMessage.FunctionCall,
-          } as OpenAICompletionParam)
-      );
+      convertMessagesToOpenAIParams(messages);
 
     if (params.stream) {
       const stream = this._streamResponseChunks(messages, options, runManager);
@@ -658,7 +686,12 @@ export class ChatOpenAI<
       }
       if (openAIMessage.additional_kwargs.function_call?.arguments) {
         count += await this.getNumTokens(
-          openAIMessage.additional_kwargs.function_call?.arguments
+          // Remove newlines and spaces
+          JSON.stringify(
+            JSON.parse(
+              openAIMessage.additional_kwargs.function_call?.arguments
+            )
+          )
         );
       }
 
@@ -851,7 +884,8 @@ export class PromptLayerChatOpenAI extends ChatOpenAI {
             | "system"
             | "assistant"
             | "user"
-            | "function",
+            | "function"
+            | "tool",
           content: message.content,
         };
       } else {
diff --git a/langchain/src/chat_models/tests/chatopenai-extended.int.test.ts b/langchain/src/chat_models/tests/chatopenai-extended.int.test.ts
new file mode 100644
index 000000000000..995da52a22f1
--- /dev/null
+++ b/langchain/src/chat_models/tests/chatopenai-extended.int.test.ts
@@ -0,0 +1,176 @@
+import { test, expect } from "@jest/globals";
+import { ChatOpenAI } from "../openai.js";
+import { HumanMessage, ToolMessage } from "../../schema/index.js";
+
+test("Test ChatOpenAI JSON mode", async () => {
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 128,
+  }).bind({
+    response_format: {
+      type: "json_object",
+    },
+  });
+  const message = new HumanMessage("Hello!");
+  const res = await chat.invoke([["system", "Only return JSON"], message]);
+  console.log(JSON.stringify(res));
+});
+
+test("Test ChatOpenAI seed", async () => {
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 128,
+    temperature: 1,
+  }).bind({
+    seed: 123454930394983,
+  });
+  const message = new HumanMessage("Say something random!");
+  const res = await chat.invoke([message]);
+  console.log(JSON.stringify(res));
+  const res2 = await chat.invoke([message]);
+  expect(res).toEqual(res2);
+});
+
+test("Test ChatOpenAI tool calling", async () => {
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 128,
+  }).bind({
+    tools: [
+      {
+        type: "function",
+        function: {
+          name: "get_current_weather",
+          description: "Get the current weather in a given location",
+          parameters: {
+            type: "object",
+            properties: {
+              location: {
+                type: "string",
+                description: "The city and state, e.g. San Francisco, CA",
+              },
+              unit: { type: "string", enum: ["celsius", "fahrenheit"] },
+            },
+            required: ["location"],
+          },
+        },
+      },
+    ],
+    tool_choice: "auto",
+  });
+  const res = await chat.invoke([
+    ["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
+  ]);
+  console.log(JSON.stringify(res));
+  expect(res.additional_kwargs.tool_calls?.length).toBeGreaterThan(1);
+});
+
+test("Test ChatOpenAI tool calling with ToolMessages", async () => {
+  function getCurrentWeather(location: string) {
+    if (location.toLowerCase().includes("tokyo")) {
+      return JSON.stringify({ location, temperature: "10", unit: "celsius" });
+    } else if (location.toLowerCase().includes("san francisco")) {
+      return JSON.stringify({
+        location,
+        temperature: "72",
+        unit: "fahrenheit",
+      });
+    } else {
+      return JSON.stringify({ location, temperature: "22", unit: "celsius" });
+    }
+  }
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 128,
+  }).bind({
+    tools: [
+      {
+        type: "function",
+        function: {
+          name: "get_current_weather",
+          description: "Get the current weather in a given location",
+          parameters: {
+            type: "object",
+            properties: {
+              location: {
+                type: "string",
+                description: "The city and state, e.g. San Francisco, CA",
+              },
+              unit: { type: "string", enum: ["celsius", "fahrenheit"] },
+            },
+            required: ["location"],
+          },
+        },
+      },
+    ],
+    tool_choice: "auto",
+  });
+  const res = await chat.invoke([
+    ["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
+  ]);
+  console.log(JSON.stringify(res));
+  expect(res.additional_kwargs.tool_calls?.length).toBeGreaterThan(1);
+  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+  const toolMessages = res.additional_kwargs.tool_calls!.map(
+    (toolCall) =>
+      new ToolMessage({
+        tool_call_id: toolCall.id,
+        name: toolCall.function.name,
+        content: getCurrentWeather(
+          JSON.parse(toolCall.function.arguments).location
+        ),
+      })
+  );
+  const finalResponse = await chat.invoke([
+    ["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
+    res,
+    ...toolMessages,
+  ]);
+  console.log(finalResponse);
+});
+
+test("Test ChatOpenAI tool calling with streaming", async () => {
+  const chat = new ChatOpenAI({
+    modelName: "gpt-3.5-turbo-1106",
+    maxTokens: 256,
+  }).bind({
+    tools: [
+      {
+        type: "function",
+        function: {
+          name: "get_current_weather",
+          description: "Get the current weather in a given location",
+          parameters: {
+            type: "object",
+            properties: {
+              location: {
+                type: "string",
+                description: "The city and state, e.g. San Francisco, CA",
+              },
+              unit: { type: "string", enum: ["celsius", "fahrenheit"] },
+            },
+            required: ["location"],
+          },
+        },
+      },
+    ],
+    tool_choice: "auto",
+  });
+  const stream = await chat.stream([
+    ["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
+  ]);
+  let finalChunk;
+  const chunks = [];
+  for await (const chunk of stream) {
+    console.log(chunk.additional_kwargs.tool_calls);
+    chunks.push(chunk);
+    if (!finalChunk) {
+      finalChunk = chunk;
+    } else {
+      finalChunk = finalChunk.concat(chunk);
+    }
+  }
+  expect(chunks.length).toBeGreaterThan(1);
+  console.log(finalChunk?.additional_kwargs.tool_calls);
+  expect(finalChunk?.additional_kwargs.tool_calls?.length).toBeGreaterThan(1);
+});
diff --git a/langchain/src/chat_models/tests/chatopenai.int.test.ts b/langchain/src/chat_models/tests/chatopenai.int.test.ts
index 5f183d6d0dba..5d712f2b3bb2 100644
--- a/langchain/src/chat_models/tests/chatopenai.int.test.ts
+++ b/langchain/src/chat_models/tests/chatopenai.int.test.ts
@@ -407,7 +407,7 @@ test("Test ChatOpenAI stream method, timeout error thrown from SDK", async () =>
 test("Function calling with streaming", async () => {
   let finalResult: BaseMessage | undefined;
   const modelForFunctionCalling = new ChatOpenAI({
-    modelName: "gpt-4-1106-preview",
+    modelName: "gpt-3.5-turbo",
     temperature: 0,
     callbacks: [
       {
@@ -775,32 +775,3 @@ test("Test ChatOpenAI token usage reporting for streaming calls", async () => {
     expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed);
   }
 });
-
-test("Test ChatOpenAI JSON mode", async () => {
-  const chat = new ChatOpenAI({
-    modelName: "gpt-4-1106-preview",
-    maxTokens: 128,
-  }).bind({
-    response_format: {
-      type: "json_object",
-    },
-  });
-  const message = new HumanMessage("Hello!");
-  const res = await chat.invoke([["system", "Only return JSON"], message]);
-  console.log(JSON.stringify(res));
-});
-
-test("Test ChatOpenAI seed", async () => {
-  const chat = new ChatOpenAI({
-    modelName: "gpt-4-1106-preview",
-    maxTokens: 128,
-    temperature: 1,
-  }).bind({
-    seed: 123454930394983,
-  });
-  const message = new HumanMessage("Say something random!");
-  const res = await chat.invoke([message]);
-  console.log(JSON.stringify(res));
-  const res2 = await chat.invoke([message]);
-  expect(res).toEqual(res2);
-});
diff --git a/langchain/src/runnables/remote.ts b/langchain/src/runnables/remote.ts
index 9c2d16db1eaf..5562938f1aec 100644
--- a/langchain/src/runnables/remote.ts
+++ b/langchain/src/runnables/remote.ts
@@ -17,6 +17,8 @@ import {
   HumanMessageChunk,
   SystemMessage,
   SystemMessageChunk,
+  ToolMessage,
+  ToolMessageChunk,
 } from "../schema/index.js";
 import { StringPromptValue } from "../prompts/base.js";
 import { ChatPromptValue } from "../prompts/chat.js";
@@ -71,6 +73,12 @@ function revive(obj: any): any {
       name: obj.name,
     });
   }
+  if (obj.type === "tool") {
+    return new ToolMessage({
+      content: obj.content,
+      tool_call_id: obj.tool_call_id,
+    });
+  }
   if (obj.type === "ai") {
     return new AIMessage({
       content: obj.content,
@@ -99,6 +107,12 @@ function revive(obj: any): any {
       name: obj.name,
     });
   }
+  if (obj.type === "tool") {
+    return new ToolMessageChunk({
+      content: obj.content,
+      tool_call_id: obj.tool_call_id,
+    });
+  }
   if (obj.type === "ai") {
     return new AIMessageChunk({
       content: obj.content,
diff --git a/langchain/src/schema/index.ts b/langchain/src/schema/index.ts
index 802a09e8c379..c04553d77e6d 100644
--- a/langchain/src/schema/index.ts
+++ b/langchain/src/schema/index.ts
@@ -85,6 +85,7 @@ export interface StoredMessageData {
   content: string;
   role: string | undefined;
   name: string | undefined;
+  tool_call_id: string | undefined;
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   additional_kwargs?: Record<string, any>;
 }
@@ -99,7 +100,13 @@ export interface StoredGeneration {
   message?: StoredMessage;
 }
 
-export type MessageType = "human" | "ai" | "generic" | "system" | "function";
+export type MessageType =
+  | "human"
+  | "ai"
+  | "generic"
+  | "system"
+  | "function"
+  | "tool";
 
 export type MessageContent =
   | string
@@ -114,6 +121,7 @@ export interface BaseMessageFields {
   name?: string;
   additional_kwargs?: {
     function_call?: OpenAIClient.Chat.ChatCompletionMessage.FunctionCall;
+    tool_calls?: OpenAIClient.Chat.ChatCompletionMessageToolCall[];
     [key: string]: unknown;
   };
 }
@@ -126,6 +134,10 @@ export interface FunctionMessageFieldsWithName extends BaseMessageFields {
   name: string;
 }
 
+export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields {
+  tool_call_id: string;
+}
+
 function mergeContent(
   firstContent: MessageContent,
   secondContent: MessageContent
@@ -232,6 +244,18 @@ export abstract class BaseMessage
   }
 }
 
+// TODO: Deprecate when SDK typing is updated
+export type OpenAIToolCall = OpenAIClient.ChatCompletionMessageToolCall & {
+  index: number;
+};
+
+function isOpenAIToolCallArray(value?: unknown): value is OpenAIToolCall[] {
+  return (
+    Array.isArray(value) &&
+    value.every((v) => typeof (v as OpenAIToolCall).index === "number")
+  );
+}
+
 /**
  * Represents a chunk of a message, which can be concatenated with other
 * message chunks. It includes a method `_merge_kwargs_dict()` for merging
@@ -264,6 +288,32 @@ export abstract class BaseMessageChunk extends BaseMessage {
           merged[key] as NonNullable<BaseMessageFields["additional_kwargs"]>,
           value as NonNullable<BaseMessageFields["additional_kwargs"]>
         );
+      } else if (
+        key === "tool_calls" &&
+        isOpenAIToolCallArray(merged[key]) &&
+        isOpenAIToolCallArray(value)
+      ) {
+        for (const toolCall of value) {
+          if (merged[key]?.[toolCall.index] !== undefined) {
+            merged[key] = merged[key]?.map((value, i) => {
+              if (i !== toolCall.index) {
+                return value;
+              }
+              return {
+                ...value,
+                ...toolCall,
+                function: {
+                  name: toolCall.function.name ?? value.function.name,
+                  arguments:
+                    (value.function.arguments ?? "") +
+                    (toolCall.function.arguments ?? ""),
+                },
+              };
+            });
+          } else {
+            (merged[key] as OpenAIToolCall[])[toolCall.index] = toolCall;
+          }
+        }
       } else {
         throw new Error(
           `additional_kwargs[${key}] already exists in this message chunk.`
@@ -467,6 +517,74 @@ export class FunctionMessageChunk extends BaseMessageChunk {
   }
 }
 
+/**
+ * Represents a tool message in a conversation.
+ */
+export class ToolMessage extends BaseMessage {
+  static lc_name() {
+    return "ToolMessage";
+  }
+
+  tool_call_id: string;
+
+  constructor(fields: ToolMessageFieldsWithToolCallId);
+
+  constructor(
+    fields: string | BaseMessageFields,
+    tool_call_id: string,
+    name?: string
+  );
+
+  constructor(
+    fields: string | ToolMessageFieldsWithToolCallId,
+    tool_call_id?: string,
+    name?: string
+  ) {
+    if (typeof fields === "string") {
+      // eslint-disable-next-line no-param-reassign, @typescript-eslint/no-non-null-assertion
+      fields = { content: fields, name, tool_call_id: tool_call_id! };
+    }
+    super(fields);
+    this.tool_call_id = fields.tool_call_id;
+  }
+
+  _getType(): MessageType {
+    return "tool";
+  }
+}
+
+/**
+ * Represents a chunk of a tool message, which can be concatenated
+ * with other tool message chunks.
+ */
+export class ToolMessageChunk extends BaseMessageChunk {
+  tool_call_id: string;
+
+  constructor(fields: ToolMessageFieldsWithToolCallId) {
+    super(fields);
+    this.tool_call_id = fields.tool_call_id;
+  }
+
+  static lc_name() {
+    return "ToolMessageChunk";
+  }
+
+  _getType(): MessageType {
+    return "tool";
+  }
+
+  concat(chunk: ToolMessageChunk) {
+    return new ToolMessageChunk({
+      content: mergeContent(this.content, chunk.content),
+      additional_kwargs: ToolMessageChunk._mergeAdditionalKwargs(
+        this.additional_kwargs,
+        chunk.additional_kwargs
+      ),
+      tool_call_id: this.tool_call_id,
+    });
+  }
+}
+
 /**
  * Represents a chat message in a conversation.
  */
@@ -645,6 +763,7 @@ function mapV1MessageToStoredMessage(
       content: v1Message.text,
       role: v1Message.role,
       name: undefined,
+      tool_call_id: undefined,
     },
   };
 }
@@ -666,6 +785,13 @@ export function mapStoredMessageToChatMessage(message: StoredMessage) {
       return new FunctionMessage(
         storedMessage.data as FunctionMessageFieldsWithName
       );
+    case "tool":
+      if (storedMessage.data.tool_call_id === undefined) {
+        throw new Error("Tool call ID must be defined for tool messages");
+      }
+      return new ToolMessage(
+        storedMessage.data as ToolMessageFieldsWithToolCallId
+      );
     case "chat": {
       if (storedMessage.data.role === undefined) {
         throw new Error("Role must be defined for chat messages");
diff --git a/langchain/src/tools/convert_to_openai.ts b/langchain/src/tools/convert_to_openai.ts
index 459eb41629cb..935bf6046306 100644
--- a/langchain/src/tools/convert_to_openai.ts
+++ b/langchain/src/tools/convert_to_openai.ts
@@ -18,3 +18,16 @@ export function formatToOpenAIFunction(
     parameters: zodToJsonSchema(tool.schema),
   };
 }
+
+export function formatToOpenAITool(
+  tool: StructuredTool
+): OpenAIClient.Chat.ChatCompletionTool {
+  return {
+    type: "function",
+    function: {
+      name: tool.name,
+      description: tool.description,
+      parameters: zodToJsonSchema(tool.schema),
+    },
+  };
+}
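Usage sketch (illustrative, not part of the patch): how the call options and ToolMessage class added above fit together end to end. It assumes the `langchain/chat_models/openai` and `langchain/schema` package entrypoints; the weather tool and its canned result are hypothetical, shaped after the integration tests in this diff.

import { ChatOpenAI } from "langchain/chat_models/openai";
import { ToolMessage } from "langchain/schema";

async function main() {
  // Bind an OpenAI-format tool definition plus tool_choice via the new call options.
  const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo-1106" }).bind({
    tools: [
      {
        type: "function",
        function: {
          name: "get_current_weather", // hypothetical tool
          description: "Get the current weather in a given location",
          parameters: {
            type: "object",
            properties: { location: { type: "string" } },
            required: ["location"],
          },
        },
      },
    ],
    tool_choice: "auto",
  });
  const res = await model.invoke([
    ["human", "What's the weather in Tokyo?"],
  ]);
  // Each requested call carries an id; the ToolMessage answering it must echo
  // that id back as tool_call_id so the model can pair requests with results.
  const toolMessages = (res.additional_kwargs.tool_calls ?? []).map(
    (toolCall) =>
      new ToolMessage({
        tool_call_id: toolCall.id,
        content: JSON.stringify({ temperature: "10", unit: "celsius" }), // canned result
      })
  );
  const finalResponse = await model.invoke([
    ["human", "What's the weather in Tokyo?"],
    res,
    ...toolMessages,
  ]);
  console.log(finalResponse);
}

main();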