Add tools and tool_choice param for ChatOpenAI (langchain-ai#3186)
* Add tools and tool_choice param

* Support streaming and merging tool call chunks

* Finish tool message support

* Repurpose tools
jacoblee93 authored Nov 8, 2023
1 parent ed61f3c commit 5169279
Showing 6 changed files with 399 additions and 65 deletions.
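For orientation before the diff: after this change, OpenAI-format tools and a `tool_choice` can be bound directly onto the model as call options. A minimal usage sketch (the model name matches the tests below; the tool schema is illustrative, and an `OPENAI_API_KEY` environment variable is assumed):

```ts
import { ChatOpenAI } from "langchain/chat_models/openai";

const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo-1106" }).bind({
  tools: [
    {
      type: "function",
      function: {
        // Illustrative tool, mirroring the tests below
        name: "get_current_weather",
        description: "Get the current weather in a given location",
        parameters: {
          type: "object",
          properties: {
            location: { type: "string" },
            unit: { type: "string", enum: ["celsius", "fahrenheit"] },
          },
          required: ["location"],
        },
      },
    },
  ],
  tool_choice: "auto",
});

const res = await model.invoke([["human", "What's the weather in Paris?"]]);
console.log(res.additional_kwargs.tool_calls);
```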
102 changes: 68 additions & 34 deletions langchain/src/chat_models/openai.ts
@@ -13,9 +13,11 @@ import {
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
+ToolMessage,
+ToolMessageChunk,
} from "../schema/index.js";
import { StructuredTool } from "../tools/base.js";
-import { formatToOpenAIFunction } from "../tools/convert_to_openai.js";
+import { formatToOpenAITool } from "../tools/convert_to_openai.js";
import {
AzureOpenAIInput,
OpenAICallOptions,
@@ -60,7 +62,8 @@ function extractGenericMessageCustomRole(message: ChatMessage) {
message.role !== "system" &&
message.role !== "assistant" &&
message.role !== "user" &&
message.role !== "function"
message.role !== "function" &&
message.role !== "tool"
) {
console.warn(`Unknown message role: ${message.role}`);
}
@@ -79,6 +82,8 @@ function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
return "user";
case "function":
return "function";
case "tool":
return "tool";
case "generic": {
if (!ChatMessage.isInstance(message))
throw new Error("Invalid generic chat message");
@@ -96,6 +101,7 @@ function openAIResponseToChatMessage(
case "assistant":
return new AIMessage(message.content || "", {
function_call: message.function_call,
+tool_calls: message.tool_calls,
});
default:
return new ChatMessage(message.content || "", message.role ?? "unknown");
@@ -114,6 +120,10 @@ function _convertDeltaToMessageChunk(
additional_kwargs = {
function_call: delta.function_call,
};
+} else if (delta.tool_calls) {
+additional_kwargs = {
+tool_calls: delta.tool_calls,
+};
} else {
additional_kwargs = {};
}
@@ -129,15 +139,37 @@
additional_kwargs,
additional_kwargs,
name: delta.name,
});
} else if (role === "tool") {
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: delta.tool_call_id,
});
} else {
return new ChatMessageChunk({ content, role });
}
}
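The `ToolMessageChunk` branch above only forwards the raw `tool_calls` deltas; per the commit message, merging streamed tool call chunks happens in the chunk `concat` logic in the schema module (one of the changed files not expanded in this view). Conceptually, deltas are keyed by `index` and their partial `arguments` strings concatenated, roughly like this simplified sketch (the types and helper are illustrative, not the commit's code):

```ts
interface ToolCallChunk {
  index: number;
  id?: string;
  function?: { name?: string; arguments?: string };
}

// Fold one streamed delta batch into the accumulated tool calls,
// keying on `index` and concatenating partial `arguments` strings.
function mergeToolCallChunks(
  acc: ToolCallChunk[],
  incoming: ToolCallChunk[]
): ToolCallChunk[] {
  const merged = acc.map((c) => ({ ...c, function: { ...c.function } }));
  for (const delta of incoming) {
    const existing = merged.find((c) => c.index === delta.index);
    if (existing === undefined) {
      merged.push(delta);
    } else {
      existing.id ??= delta.id;
      existing.function = {
        name: existing.function?.name ?? delta.function?.name,
        arguments:
          (existing.function?.arguments ?? "") +
          (delta.function?.arguments ?? ""),
      };
    }
  }
  return merged;
}
```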

+function convertMessagesToOpenAIParams(messages: BaseMessage[]) {
+// TODO: Function messages do not support array content, fix cast
+return messages.map(
+(message) =>
+({
+role: messageToOpenAIRole(message),
+content: message.content,
+name: message.name,
+function_call: message.additional_kwargs.function_call,
+tool_calls: message.additional_kwargs.tool_calls,
+tool_call_id: (message as ToolMessage).tool_call_id,
+} as OpenAICompletionParam)
+);
+}
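To make the new mapping concrete: a `ToolMessage` now carries its `tool_call_id` through to the wire format. A small sketch of the round trip (the id and payload are made up):

```ts
import { ToolMessage } from "langchain/schema";

const toolResult = new ToolMessage({
  tool_call_id: "call_abc123", // made-up id for illustration
  name: "get_current_weather",
  content: JSON.stringify({ temperature: "72", unit: "fahrenheit" }),
});

// convertMessagesToOpenAIParams([toolResult]) yields roughly:
// { role: "tool", content: '{"temperature":"72",...}',
//   name: "get_current_weather", tool_call_id: "call_abc123" }
```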

export interface ChatOpenAICallOptions
extends OpenAICallOptions,
BaseFunctionCallOptions {
-tools?: StructuredTool[];
+tools?: StructuredTool[] | OpenAIClient.ChatCompletionTool[];
+tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption;
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
@@ -179,6 +211,7 @@ export class ChatOpenAI<
"function_call",
"functions",
"tools",
"tool_choice",
"promptIndex",
"response_format",
"seed",
@@ -343,7 +376,20 @@ export class ChatOpenAI<
invocationParams(
options?: this["ParsedCallOptions"]
): Omit<OpenAIClient.Chat.ChatCompletionCreateParams, "messages"> {
-return {
+function isStructuredToolArray(
+tools?: unknown[]
+): tools is StructuredTool[] {
+return (
+tools !== undefined &&
+tools.every((tool) =>
+Array.isArray((tool as StructuredTool).lc_namespace)
+)
+);
+}
+const params: Omit<
+OpenAIClient.Chat.ChatCompletionCreateParams,
+"messages"
+> = {
model: this.modelName,
temperature: this.temperature,
top_p: this.topP,
@@ -355,16 +401,17 @@
stop: options?.stop ?? this.stop,
user: this.user,
stream: this.streaming,
-functions:
-options?.functions ??
-(options?.tools
-? options?.tools.map(formatToOpenAIFunction)
-: undefined),
+functions: options?.functions,
function_call: options?.function_call,
+tools: isStructuredToolArray(options?.tools)
+? options?.tools.map(formatToOpenAITool)
+: options?.tools,
+tool_choice: options?.tool_choice,
response_format: options?.response_format,
seed: options?.seed,
...this.modelKwargs,
};
+return params;
}
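Because detection relies on `lc_namespace`, callers can pass either LangChain `StructuredTool` instances (converted via `formatToOpenAITool`) or raw `OpenAIClient.ChatCompletionTool` objects. A hedged sketch of the `StructuredTool` path, using the stock `Calculator` tool as an example:

```ts
import { ChatOpenAI } from "langchain/chat_models/openai";
import { Calculator } from "langchain/tools/calculator";

const boundModel = new ChatOpenAI({
  modelName: "gpt-3.5-turbo-1106",
}).bind({
  // StructuredTool instances are detected via lc_namespace and
  // converted with formatToOpenAITool before hitting the API.
  tools: [new Calculator()],
  tool_choice: "auto",
});

const res = await boundModel.invoke([["human", "What is 39 * 47?"]]);
```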

/** @ignore */
@@ -386,17 +433,8 @@
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
-const messagesMapped: OpenAICompletionParam[] = messages.map(
-// TODO: Function messages do not support array content, fix cast
-(message) =>
-({
-role: messageToOpenAIRole(message),
-content: message.content,
-name: message.name,
-function_call: message.additional_kwargs
-.function_call as OpenAIClient.Chat.ChatCompletionMessage.FunctionCall,
-} as OpenAICompletionParam)
-);
+const messagesMapped: OpenAICompletionParam[] =
+convertMessagesToOpenAIParams(messages);
const params = {
...this.invocationParams(options),
messages: messagesMapped,
@@ -419,7 +457,7 @@
};
if (typeof chunk.content !== "string") {
console.log(
"[WARNING:] Received non-string content from OpenAI. This is currently not supported."
"[WARNING]: Received non-string content from OpenAI. This is currently not supported."
);
continue;
}
@@ -461,17 +499,7 @@
const tokenUsage: TokenUsage = {};
const params = this.invocationParams(options);
const messagesMapped: OpenAICompletionParam[] =
-// TODO: Function messages do not support array content, fix cast
-messages.map(
-(message) =>
-({
-role: messageToOpenAIRole(message),
-content: message.content,
-name: message.name,
-function_call: message.additional_kwargs
-.function_call as OpenAIClient.Chat.ChatCompletionMessage.FunctionCall,
-} as OpenAICompletionParam)
-);
+convertMessagesToOpenAIParams(messages);

if (params.stream) {
const stream = this._streamResponseChunks(messages, options, runManager);
@@ -658,7 +686,12 @@
}
if (openAIMessage.additional_kwargs.function_call?.arguments) {
count += await this.getNumTokens(
-openAIMessage.additional_kwargs.function_call?.arguments
+// Remove newlines and spaces
+JSON.stringify(
+JSON.parse(
+openAIMessage.additional_kwargs.function_call?.arguments
+)
+)
);
}
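The parse/stringify round trip above minifies the arguments JSON before counting tokens, so pretty-printed function arguments are not over-counted. For example:

```ts
const pretty = `{
  "location": "San Francisco, CA",
  "unit": "celsius"
}`;

// The round trip drops the newlines and indentation whitespace:
// {"location":"San Francisco, CA","unit":"celsius"}
const minified = JSON.stringify(JSON.parse(pretty));
```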

@@ -851,7 +884,8 @@ export class PromptLayerChatOpenAI extends ChatOpenAI {
| "system"
| "assistant"
| "user"
| "function",
| "function"
| "tool",
content: message.content,
};
} else {
176 changes: 176 additions & 0 deletions langchain/src/chat_models/tests/chatopenai-extended.int.test.ts
@@ -0,0 +1,176 @@
import { test, expect } from "@jest/globals";
import { ChatOpenAI } from "../openai.js";
import { HumanMessage, ToolMessage } from "../../schema/index.js";

test("Test ChatOpenAI JSON mode", async () => {
const chat = new ChatOpenAI({
modelName: "gpt-3.5-turbo-1106",
maxTokens: 128,
}).bind({
response_format: {
type: "json_object",
},
});
const message = new HumanMessage("Hello!");
const res = await chat.invoke([["system", "Only return JSON"], message]);
console.log(JSON.stringify(res));
});

test("Test ChatOpenAI seed", async () => {
const chat = new ChatOpenAI({
modelName: "gpt-3.5-turbo-1106",
maxTokens: 128,
temperature: 1,
}).bind({
seed: 123454930394983,
});
const message = new HumanMessage("Say something random!");
const res = await chat.invoke([message]);
console.log(JSON.stringify(res));
const res2 = await chat.invoke([message]);
expect(res).toEqual(res2);
});

test("Test ChatOpenAI tool calling", async () => {
const chat = new ChatOpenAI({
modelName: "gpt-3.5-turbo-1106",
maxTokens: 128,
}).bind({
tools: [
{
type: "function",
function: {
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
unit: { type: "string", enum: ["celsius", "fahrenheit"] },
},
required: ["location"],
},
},
},
],
tool_choice: "auto",
});
const res = await chat.invoke([
["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
]);
console.log(JSON.stringify(res));
expect(res.additional_kwargs.tool_calls?.length).toBeGreaterThan(1);
});
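Each returned tool call carries its arguments as a JSON string; continuing from `res` in the test above, a hedged sketch of unpacking them (the same pattern the next test uses):

```ts
// Unpack the model's tool calls into (name, parsed-args) pairs.
const calls = (res.additional_kwargs.tool_calls ?? []).map((toolCall) => ({
  name: toolCall.function.name,
  args: JSON.parse(toolCall.function.arguments),
}));
console.log(calls);
```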

test("Test ChatOpenAI tool calling with ToolMessages", async () => {
function getCurrentWeather(location: string) {
if (location.toLowerCase().includes("tokyo")) {
return JSON.stringify({ location, temperature: "10", unit: "celsius" });
} else if (location.toLowerCase().includes("san francisco")) {
return JSON.stringify({
location,
temperature: "72",
unit: "fahrenheit",
});
} else {
return JSON.stringify({ location, temperature: "22", unit: "celsius" });
}
}
const chat = new ChatOpenAI({
modelName: "gpt-3.5-turbo-1106",
maxTokens: 128,
}).bind({
tools: [
{
type: "function",
function: {
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
unit: { type: "string", enum: ["celsius", "fahrenheit"] },
},
required: ["location"],
},
},
},
],
tool_choice: "auto",
});
const res = await chat.invoke([
["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
]);
console.log(JSON.stringify(res));
expect(res.additional_kwargs.tool_calls?.length).toBeGreaterThan(1);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const toolMessages = res.additional_kwargs.tool_calls!.map(
(toolCall) =>
new ToolMessage({
tool_call_id: toolCall.id,
name: toolCall.function.name,
content: getCurrentWeather(
JSON.parse(toolCall.function.arguments).location
),
})
);
const finalResponse = await chat.invoke([
["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
res,
...toolMessages,
]);
console.log(finalResponse);
});

test("Test ChatOpenAI tool calling with streaming", async () => {
const chat = new ChatOpenAI({
modelName: "gpt-3.5-turbo-1106",
maxTokens: 256,
}).bind({
tools: [
{
type: "function",
function: {
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
unit: { type: "string", enum: ["celsius", "fahrenheit"] },
},
required: ["location"],
},
},
},
],
tool_choice: "auto",
});
const stream = await chat.stream([
["human", "What's the weather like in San Francisco, Tokyo, and Paris?"],
]);
let finalChunk;
const chunks = [];
for await (const chunk of stream) {
console.log(chunk.additional_kwargs.tool_calls);
chunks.push(chunk);
if (!finalChunk) {
finalChunk = chunk;
} else {
finalChunk = finalChunk.concat(chunk);
}
}
expect(chunks.length).toBeGreaterThan(1);
console.log(finalChunk?.additional_kwargs.tool_calls);
expect(finalChunk?.additional_kwargs.tool_calls?.length).toBeGreaterThan(1);
});