Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(openai): Refactor to allow easier subclassing #7598

Merged
merged 3 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 144 additions & 145 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,145 +145,6 @@ export function messageToOpenAIRole(message: BaseMessage): OpenAIRoleEnum {
}
}

/**
 * Convert a non-streaming OpenAI chat-completion message into a LangChain
 * BaseMessage.
 *
 * @param message The `choices[i].message` payload from the completion.
 * @param rawResponse The full completion response; supplies the message id,
 *   model name, usage and system fingerprint metadata.
 * @param includeRawResponse When truthy, the entire raw API response is
 *   attached under `additional_kwargs.__raw_response`.
 * @returns An AIMessage for assistant replies; any other role falls back to a
 *   generic ChatMessage.
 */
function openAIResponseToChatMessage(
  message: OpenAIClient.Chat.Completions.ChatCompletionMessage,
  rawResponse: OpenAIClient.Chat.Completions.ChatCompletion,
  includeRawResponse?: boolean
): BaseMessage {
  const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
    | OpenAIToolCall[]
    | undefined;
  switch (message.role) {
    case "assistant": {
      // Partition tool calls into parseable and malformed ones so a single
      // bad tool call does not discard the rest of the message.
      const toolCalls = [];
      const invalidToolCalls = [];
      for (const rawToolCall of rawToolCalls ?? []) {
        try {
          toolCalls.push(parseToolCall(rawToolCall, { returnId: true }));
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
        } catch (e: any) {
          invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
        }
      }
      const additional_kwargs: Record<string, unknown> = {
        function_call: message.function_call,
        tool_calls: rawToolCalls,
      };
      // Fix: check truthiness rather than `!== undefined`, so an explicitly
      // passed `false` does NOT attach the raw response. This matches the
      // check used by _convertDeltaToMessageChunk on the streaming path.
      if (includeRawResponse) {
        additional_kwargs.__raw_response = rawResponse;
      }
      // NOTE(review): usage is only surfaced when system_fingerprint is
      // present — preserved as-is; confirm whether usage should always be
      // included regardless of fingerprint.
      const response_metadata: Record<string, unknown> | undefined = {
        model_name: rawResponse.model,
        ...(rawResponse.system_fingerprint
          ? {
              usage: { ...rawResponse.usage },
              system_fingerprint: rawResponse.system_fingerprint,
            }
          : {}),
      };

      if (message.audio) {
        additional_kwargs.audio = message.audio;
      }

      return new AIMessage({
        content: message.content || "",
        tool_calls: toolCalls,
        invalid_tool_calls: invalidToolCalls,
        additional_kwargs,
        response_metadata,
        id: rawResponse.id,
      });
    }
    default:
      return new ChatMessage(message.content || "", message.role ?? "unknown");
  }
}

/**
 * Map a streaming response delta onto the LangChain message-chunk class that
 * matches its role (falling back to `defaultRole` when the delta omits one).
 * Unknown or absent roles produce a generic ChatMessageChunk.
 */
function _convertDeltaToMessageChunk(
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  delta: Record<string, any>,
  rawResponse: OpenAIClient.Chat.Completions.ChatCompletionChunk,
  defaultRole?: OpenAIRoleEnum,
  includeRawResponse?: boolean
) {
  const role = delta.role ?? defaultRole;
  const content = delta.content ?? "";

  // Carry through exactly one of function_call / tool_calls, mirroring
  // whichever the API populated on this delta.
  const additional_kwargs: Record<string, unknown> = {};
  if (delta.function_call) {
    additional_kwargs.function_call = delta.function_call;
  } else if (delta.tool_calls) {
    additional_kwargs.tool_calls = delta.tool_calls;
  }
  if (includeRawResponse) {
    additional_kwargs.__raw_response = rawResponse;
  }
  if (delta.audio) {
    additional_kwargs.audio = {
      ...delta.audio,
      index: rawResponse.choices[0].index,
    };
  }

  const response_metadata = { usage: { ...rawResponse.usage } };

  switch (role) {
    case "user":
      return new HumanMessageChunk({ content, response_metadata });
    case "assistant": {
      // Partial tool calls stream in as indexed fragments; forward each one
      // as a tool_call_chunk for downstream accumulation.
      const toolCallChunks: ToolCallChunk[] = [];
      if (Array.isArray(delta.tool_calls)) {
        for (const rawToolCall of delta.tool_calls) {
          toolCallChunks.push({
            name: rawToolCall.function?.name,
            args: rawToolCall.function?.arguments,
            id: rawToolCall.id,
            index: rawToolCall.index,
            type: "tool_call_chunk",
          });
        }
      }
      return new AIMessageChunk({
        content,
        tool_call_chunks: toolCallChunks,
        additional_kwargs,
        id: rawResponse.id,
        response_metadata,
      });
    }
    case "system":
      return new SystemMessageChunk({ content, response_metadata });
    case "developer":
      // "developer" is surfaced as a system chunk, tagged so the original
      // OpenAI role can be reconstructed later.
      return new SystemMessageChunk({
        content,
        response_metadata,
        additional_kwargs: {
          __openai_role__: "developer",
        },
      });
    case "function":
      return new FunctionMessageChunk({
        content,
        additional_kwargs,
        name: delta.name,
        response_metadata,
      });
    case "tool":
      return new ToolMessageChunk({
        content,
        additional_kwargs,
        tool_call_id: delta.tool_call_id,
        response_metadata,
      });
    default:
      return new ChatMessageChunk({ content, role, response_metadata });
  }
}

// Used in LangSmith, export is important here
export function _convertMessagesToOpenAIParams(
messages: BaseMessage[],
Expand Down Expand Up @@ -1290,6 +1151,146 @@ export class ChatOpenAI<
return params;
}

/**
 * Convert a non-streaming OpenAI chat-completion message into a LangChain
 * BaseMessage. Protected so subclasses (e.g. provider-specific wrappers) can
 * override it to surface extra fields.
 *
 * @param message The `choices[i].message` payload from the completion.
 * @param rawResponse The full completion response; supplies the message id,
 *   model name, usage and system fingerprint metadata.
 * @returns An AIMessage for assistant replies; any other role falls back to a
 *   generic ChatMessage.
 */
protected _convertOpenAIChatCompletionMessageToBaseMessage(
  message: OpenAIClient.Chat.Completions.ChatCompletionMessage,
  rawResponse: OpenAIClient.Chat.Completions.ChatCompletion
): BaseMessage {
  const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
    | OpenAIToolCall[]
    | undefined;
  switch (message.role) {
    case "assistant": {
      // Partition tool calls into parseable and malformed ones so a single
      // bad tool call does not discard the rest of the message.
      const toolCalls = [];
      const invalidToolCalls = [];
      for (const rawToolCall of rawToolCalls ?? []) {
        try {
          toolCalls.push(parseToolCall(rawToolCall, { returnId: true }));
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
        } catch (e: any) {
          invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message));
        }
      }
      const additional_kwargs: Record<string, unknown> = {
        function_call: message.function_call,
        tool_calls: rawToolCalls,
      };
      // Fix: check truthiness rather than `!== undefined`, so setting
      // `__includeRawResponse: false` does NOT attach the raw response.
      // This matches _convertOpenAIDeltaToBaseMessageChunk, which already
      // checks `if (this.__includeRawResponse)`.
      if (this.__includeRawResponse) {
        additional_kwargs.__raw_response = rawResponse;
      }
      // NOTE(review): usage is only surfaced when system_fingerprint is
      // present — preserved as-is; confirm whether usage should always be
      // included regardless of fingerprint.
      const response_metadata: Record<string, unknown> | undefined = {
        model_name: rawResponse.model,
        ...(rawResponse.system_fingerprint
          ? {
              usage: { ...rawResponse.usage },
              system_fingerprint: rawResponse.system_fingerprint,
            }
          : {}),
      };

      if (message.audio) {
        additional_kwargs.audio = message.audio;
      }

      return new AIMessage({
        content: message.content || "",
        tool_calls: toolCalls,
        invalid_tool_calls: invalidToolCalls,
        additional_kwargs,
        response_metadata,
        id: rawResponse.id,
      });
    }
    default:
      return new ChatMessage(
        message.content || "",
        message.role ?? "unknown"
      );
  }
}

/**
 * Map a streaming response delta onto the LangChain message-chunk class that
 * matches its role (falling back to `defaultRole` when the delta omits one).
 * Protected so subclasses can override it to surface provider-specific fields
 * on streamed chunks.
 *
 * @param delta The `choices[i].delta` payload of a streamed chunk.
 * @param rawResponse The full chunk; supplies id, usage and choice index.
 * @param defaultRole Role to assume when the delta carries none (OpenAI only
 *   sends the role on the first chunk of a message).
 * @returns A role-specific message chunk; unknown or absent roles produce a
 *   generic ChatMessageChunk.
 */
protected _convertOpenAIDeltaToBaseMessageChunk(
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  delta: Record<string, any>,
  rawResponse: OpenAIClient.Chat.Completions.ChatCompletionChunk,
  defaultRole?: OpenAIRoleEnum
) {
  const role = delta.role ?? defaultRole;
  const content = delta.content ?? "";
  // Carry through exactly one of function_call / tool_calls, mirroring
  // whichever the API populated on this delta.
  let additional_kwargs: Record<string, unknown>;
  if (delta.function_call) {
    additional_kwargs = {
      function_call: delta.function_call,
    };
  } else if (delta.tool_calls) {
    additional_kwargs = {
      tool_calls: delta.tool_calls,
    };
  } else {
    additional_kwargs = {};
  }
  // Optionally expose the raw API chunk for callers that need it.
  if (this.__includeRawResponse) {
    additional_kwargs.__raw_response = rawResponse;
  }

  if (delta.audio) {
    additional_kwargs.audio = {
      ...delta.audio,
      index: rawResponse.choices[0].index,
    };
  }

  const response_metadata = { usage: { ...rawResponse.usage } };
  if (role === "user") {
    return new HumanMessageChunk({ content, response_metadata });
  } else if (role === "assistant") {
    // Partial tool calls stream in as indexed fragments; forward each one as
    // a tool_call_chunk for downstream accumulation.
    const toolCallChunks: ToolCallChunk[] = [];
    if (Array.isArray(delta.tool_calls)) {
      for (const rawToolCall of delta.tool_calls) {
        toolCallChunks.push({
          name: rawToolCall.function?.name,
          args: rawToolCall.function?.arguments,
          id: rawToolCall.id,
          index: rawToolCall.index,
          type: "tool_call_chunk",
        });
      }
    }
    return new AIMessageChunk({
      content,
      tool_call_chunks: toolCallChunks,
      additional_kwargs,
      id: rawResponse.id,
      response_metadata,
    });
  } else if (role === "system") {
    return new SystemMessageChunk({ content, response_metadata });
  } else if (role === "developer") {
    // "developer" is surfaced as a system chunk, tagged so the original
    // OpenAI role can be reconstructed later.
    return new SystemMessageChunk({
      content,
      response_metadata,
      additional_kwargs: {
        __openai_role__: "developer",
      },
    });
  } else if (role === "function") {
    return new FunctionMessageChunk({
      content,
      additional_kwargs,
      name: delta.name,
      response_metadata,
    });
  } else if (role === "tool") {
    return new ToolMessageChunk({
      content,
      additional_kwargs,
      tool_call_id: delta.tool_call_id,
      response_metadata,
    });
  } else {
    return new ChatMessageChunk({ content, role, response_metadata });
  }
}

/** @ignore */
_identifyingParams(): Omit<
OpenAIClient.Chat.ChatCompletionCreateParams,
Expand Down Expand Up @@ -1335,11 +1336,10 @@ export class ChatOpenAI<
if (!delta) {
continue;
}
const chunk = _convertDeltaToMessageChunk(
const chunk = this._convertOpenAIDeltaToBaseMessageChunk(
delta,
data,
defaultRole,
this.__includeRawResponse
defaultRole
);
defaultRole = delta.role ?? defaultRole;
const newTokenIndices = {
Expand Down Expand Up @@ -1576,10 +1576,9 @@ export class ChatOpenAI<
const text = part.message?.content ?? "";
const generation: ChatGeneration = {
text,
message: openAIResponseToChatMessage(
message: this._convertOpenAIChatCompletionMessageToBaseMessage(
part.message ?? { role: "assistant" },
data,
this.__includeRawResponse
data
),
};
generation.generationInfo = {
Expand Down
58 changes: 58 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import { CallbackManager } from "@langchain/core/callbacks/manager";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { InMemoryCache } from "@langchain/core/caches";
import { concat } from "@langchain/core/utils/stream";
import {
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
} from "openai/resources/index.mjs";
import { ChatOpenAI } from "../chat_models.js";

// Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable
Expand Down Expand Up @@ -1227,3 +1232,56 @@ test("Allows developer messages with o1", async () => {
]);
expect(res.content).toEqual("testing");
});

// Integration test (skipped by default: requires DEEPSEEK_API_KEY and live
// network access): verifies that a subclass can hook the two protected
// conversion methods to surface a provider-specific field (DeepSeek's
// `reasoning_content`) on both streamed and non-streamed messages.
test.skip("Allow overriding", async () => {
  class ChatDeepSeek extends ChatOpenAI {
    // Streaming path: delegate to the base conversion, then copy the extra
    // field from the raw delta onto the resulting chunk.
    protected override _convertOpenAIDeltaToBaseMessageChunk(
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      delta: Record<string, any>,
      rawResponse: ChatCompletionChunk,
      defaultRole?:
        | "function"
        | "user"
        | "system"
        | "developer"
        | "assistant"
        | "tool"
    ) {
      const messageChunk = super._convertOpenAIDeltaToBaseMessageChunk(
        delta,
        rawResponse,
        defaultRole
      );
      messageChunk.additional_kwargs.reasoning_content =
        delta.reasoning_content;
      return messageChunk;
    }

    // Non-streaming path: same pattern against the full completion message.
    protected override _convertOpenAIChatCompletionMessageToBaseMessage(
      message: ChatCompletionMessage,
      rawResponse: ChatCompletion
    ) {
      const langChainMessage =
        super._convertOpenAIChatCompletionMessageToBaseMessage(
          message,
          rawResponse
        );
      // `reasoning_content` is DeepSeek-specific and absent from the OpenAI
      // SDK types, hence the cast.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      langChainMessage.additional_kwargs.reasoning_content = (
        message as any
      ).reasoning_content;
      return langChainMessage;
    }
  }
  const model = new ChatDeepSeek({
    model: "deepseek-reasoner",
    configuration: {
      baseURL: "https://api.deepseek.com",
    },
    apiKey: process.env.DEEPSEEK_API_KEY,
  });
  const res = await model.invoke("what color is the sky?");
  console.log(res);
  const stream = await model.stream("what color is the sky?");
  for await (const chunk of stream) {
    console.log(chunk);
  }
});
Loading