Skip to content

Commit

Permalink
Fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Apr 9, 2024
1 parent 79bb8c1 commit afa1bfe
Show file tree
Hide file tree
Showing 10 changed files with 326 additions and 308 deletions.
155 changes: 112 additions & 43 deletions langchain-core/src/messages/ai.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { parsePartialJson } from "../output_parsers/json.js";
import { parsePartialJson } from "../utils/json.js";
import {
BaseMessage,
BaseMessageChunk,
mergeContent,
_mergeDicts,
type MessageType,
BaseMessageFields,
_mergeLists,
} from "./base.js";
import {
InvalidToolCall,
Expand All @@ -15,40 +16,73 @@ import {
} from "./tool.js";

export type AIMessageFields = BaseMessageFields & {
tool_calls?: (ToolCall | InvalidToolCall)[];
tool_calls?: ToolCall[];
invalid_tool_calls?: InvalidToolCall[];
};

/**
* Represents an AI message in a conversation.
*/
export class AIMessage extends BaseMessage {
tool_calls?: (ToolCall | InvalidToolCall)[];
tool_calls: ToolCall[] = [];

invalid_tool_calls: InvalidToolCall[] = [];

get lc_aliases(): Record<string, string> {
// exclude snake case conversion to pascal case
return {
...super.lc_aliases,
tool_calls: "tool_calls",
invalid_tool_calls: "invalid_tool_calls",
};
}

constructor(fields: string | AIMessageFields) {
let initParams;
if (typeof fields === "string") {
super(fields);
return;
}
try {
const rawToolCalls = fields.additional_kwargs?.tool_calls;
const toolCalls = fields.tool_calls;
if (rawToolCalls !== undefined && toolCalls === undefined) {
// eslint-disable-next-line no-param-reassign
fields.tool_calls = defaultToolCallParser(rawToolCalls ?? []);
initParams = { content: fields, tool_calls: [], invalid_tool_calls: [] };
} else {
initParams = fields;
const rawToolCalls = initParams.additional_kwargs?.tool_calls;
const toolCalls = initParams.tool_calls;
if (
rawToolCalls !== undefined &&
rawToolCalls.length > 0 &&
(toolCalls === undefined || toolCalls.length === 0)
) {
console.warn(
[
"New LangChain packages are available that more efficiently handle ",
"tool calling. Please upgrade your packages to versions that set ",
"message tool calls. e.g., `yarn add @langchain/anthropic`, ",
"yarn add @langchain/openai`, etc.",
].join("\n")
);
}
try {
if (rawToolCalls !== undefined && toolCalls === undefined) {
const [toolCalls, invalidToolCalls] =
defaultToolCallParser(rawToolCalls);
initParams.tool_calls = toolCalls ?? [];
initParams.invalid_tool_calls = invalidToolCalls ?? [];
} else {
initParams.tool_calls = initParams.tool_calls ?? [];
initParams.invalid_tool_calls = initParams.invalid_tool_calls ?? [];
}
} catch (e) {
// Do nothing if parsing fails
initParams.tool_calls = [];
initParams.invalid_tool_calls = [];
}
} catch (e) {
// Do nothing if parsing fails
}
super(fields);
this.tool_calls = fields.tool_calls;
// Sadly, TypeScript only allows super() calls at root if the class has
// properties with initializers, so we have to check types twice.
super(initParams);
if (typeof initParams !== "string") {
this.tool_calls = initParams.tool_calls ?? this.tool_calls;
this.invalid_tool_calls =
initParams.invalid_tool_calls ?? this.invalid_tool_calls;
}
}

static lc_name() {
Expand All @@ -69,48 +103,75 @@ export type AIMessageChunkFields = AIMessageFields & {
* other AI message chunks.
*/
export class AIMessageChunk extends BaseMessageChunk {
// Must redeclare "tool_calls" field due to lack of support for multiple inhertiance.
tool_calls?: (ToolCall | InvalidToolCall)[];
// Must redeclare tool call fields since there is no multiple inhertiance in JS.
tool_calls: ToolCall[] = [];

tool_call_chunks?: ToolCallChunk[];
invalid_tool_calls: InvalidToolCall[] = [];

constructor(fields: AIMessageChunkFields) {
if (fields.tool_calls !== undefined) {
throw new Error(
`"tool_calls" cannot be set directly on AIMessageChunk, it is derived from "tool_call_chunks".`
);
}
if (
fields.tool_call_chunks === undefined ||
fields.tool_call_chunks.length === 0
) {
super({ tool_calls: fields.tool_call_chunks, ...fields });
tool_call_chunks: ToolCallChunk[] = [];

constructor(fields: string | AIMessageChunkFields) {
let initParams: AIMessageChunkFields;
if (typeof fields === "string") {
initParams = {
content: fields,
tool_calls: [],
invalid_tool_calls: [],
tool_call_chunks: [],
};
} else if (fields.tool_call_chunks === undefined) {
initParams = {
...fields,
tool_calls: [],
invalid_tool_calls: [],
tool_call_chunks: [],
};
} else {
// eslint-disable-next-line no-param-reassign
fields.tool_calls = fields.tool_call_chunks.map((toolCallChunk) => {
const toolCalls: ToolCall[] = [];
const invalidToolCalls: InvalidToolCall[] = [];
for (const toolCallChunk of fields.tool_call_chunks) {
let parsedArgs = {};
try {
parsedArgs = parsePartialJson(toolCallChunk.args ?? "{}") ?? {};
if (typeof parsedArgs !== "object" || Array.isArray(parsedArgs)) {
throw new Error("Malformed tool call chunk args.");
}
toolCalls.push({
name: toolCallChunk.name ?? "",
args: parsedArgs,
id: toolCallChunk.id,
});
} catch (e) {
// Do nothing if parsing fails
invalidToolCalls.push({
name: toolCallChunk.name,
args: toolCallChunk.args,
id: toolCallChunk.id,
error: "Malformed args.",
});
}
return new ToolCall({
name: toolCallChunk.name ?? "",
args: parsedArgs,
index: toolCallChunk.index,
id: toolCallChunk.id,
});
});
}
initParams = {
...fields,
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
};
}
super(fields);
this.tool_call_chunks = fields.tool_call_chunks;
this.tool_calls = fields.tool_calls;
// Sadly, TypeScript only allows super() calls at root if the class has
// properties with initializers, so we have to check types twice.
super(initParams);
this.tool_call_chunks =
initParams?.tool_call_chunks ?? this.tool_call_chunks;
this.tool_calls = initParams?.tool_calls ?? this.tool_calls;
this.invalid_tool_calls =
initParams?.invalid_tool_calls ?? this.invalid_tool_calls;
}

get lc_aliases(): Record<string, string> {
// exclude snake case conversion to pascal case
return {
...super.lc_aliases,
tool_calls: "tool_calls",
invalid_tool_calls: "invalid_tool_calls",
tool_call_chunks: "tool_call_chunks",
};
}
Expand All @@ -134,11 +195,19 @@ export class AIMessageChunk extends BaseMessageChunk {
this.response_metadata,
chunk.response_metadata
),
tool_call_chunks: [],
};
if (
this.tool_call_chunks !== undefined ||
chunk.tool_call_chunks !== undefined
) {
const rawToolCalls = _mergeLists(
this.tool_call_chunks,
chunk.tool_call_chunks
);
if (rawToolCalls !== undefined && rawToolCalls.length > 0) {
combinedFields.tool_call_chunks = rawToolCalls;
}
}
return new AIMessageChunk(combinedFields);
}
Expand Down
25 changes: 1 addition & 24 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export interface FunctionCall {

/**
* @deprecated
* Use the class with the same name in "@langchain/core/messages/tools" instead.
* Import as "OpenAIToolCall" instead
*/
export interface ToolCall {
/**
Expand Down Expand Up @@ -209,29 +209,6 @@ export abstract class BaseMessage
.kwargs as StoredMessageData,
};
}

// toChunk(): BaseMessageChunk {
// const type = this._getType();
// if (type === "human") {
// // eslint-disable-next-line @typescript-eslint/no-use-before-define
// return new HumanMessageChunk({ ...this });
// } else if (type === "ai") {
// // eslint-disable-next-line @typescript-eslint/no-use-before-define
// return new AIMessageChunk({ ...this });
// } else if (type === "system") {
// // eslint-disable-next-line @typescript-eslint/no-use-before-define
// return new SystemMessageChunk({ ...this });
// } else if (type === "function") {
// // eslint-disable-next-line @typescript-eslint/no-use-before-define
// return new FunctionMessageChunk({ ...this });
// // eslint-disable-next-line @typescript-eslint/no-use-before-define
// } else if (ChatMessage.isInstance(this)) {
// // eslint-disable-next-line @typescript-eslint/no-use-before-define
// return new ChatMessageChunk({ ...this });
// } else {
// throw new Error("Unknown message type.");
// }
// }
}

// TODO: Deprecate when SDK typing is updated
Expand Down
Loading

0 comments on commit afa1bfe

Please sign in to comment.