Skip to content

Commit

Permalink
ensure name/id fields are only yielded once for streaming tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 25, 2024
1 parent 8dda83f commit 203f366
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 17 deletions.
100 changes: 83 additions & 17 deletions libs/langchain-groq/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
ToolMessage,
OpenAIToolCall,
isAIMessage,
BaseMessageChunk,
} from "@langchain/core/messages";
import {
ChatGeneration,
Expand Down Expand Up @@ -208,7 +209,8 @@ function groqResponseToChatMessage(
}

function _convertDeltaToolCallToToolCallChunk(
toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[]
toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[],
index?: number
): ToolCallChunk[] | undefined {
if (!toolCalls?.length) return undefined;

Expand All @@ -217,13 +219,23 @@ function _convertDeltaToolCallToToolCallChunk(
name: tc.function?.name,
args: tc.function?.arguments,
type: "tool_call_chunk",
index,
}));
}

function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta: Record<string, any>
) {
delta: Record<string, any>,
index: number
): {
message: BaseMessageChunk;
toolCallData?: {
id: string;
name: string;
index: number;
type: "tool_call_chunk";
}[];
} {
const { role } = delta;
const content = delta.content ?? "";
let additional_kwargs;
Expand All @@ -239,17 +251,43 @@ function _convertDeltaToMessageChunk(
additional_kwargs = {};
}
if (role === "user") {
return new HumanMessageChunk({ content });
return {
message: new HumanMessageChunk({ content }),
};
} else if (role === "assistant") {
return new AIMessageChunk({
content,
additional_kwargs,
tool_call_chunks: _convertDeltaToolCallToToolCallChunk(delta.tool_calls),
});
const toolCallChunks = _convertDeltaToolCallToToolCallChunk(
delta.tool_calls,
index
);
return {
message: new AIMessageChunk({
content,
additional_kwargs,
tool_call_chunks: toolCallChunks
? toolCallChunks.map((tc) => ({
type: tc.type,
args: tc.args,
index: tc.index,
}))
: undefined,
}),
toolCallData: toolCallChunks
? toolCallChunks.map((tc) => ({
id: tc.id ?? "",
name: tc.name ?? "",
index: tc.index ?? index,
type: "tool_call_chunk",
}))
: undefined,
};
} else if (role === "system") {
return new SystemMessageChunk({ content });
return {
message: new SystemMessageChunk({ content }),
};
} else {
return new ChatMessageChunk({ content, role });
return {
message: new ChatMessageChunk({ content, role }),
};
}
}

Expand Down Expand Up @@ -423,6 +461,12 @@ export class ChatGroq extends BaseChatModel<
}
);
let role = "";
const toolCall: {
id: string;
name: string;
index: number;
type: "tool_call_chunk";
}[] = [];
for await (const data of response) {
const choice = data?.choices[0];
if (!choice) {
Expand All @@ -433,13 +477,34 @@ export class ChatGroq extends BaseChatModel<
if (choice.delta?.role) {
role = choice.delta.role;
}

const { message, toolCallData } = _convertDeltaToMessageChunk(
{
...choice.delta,
role,
} ?? {},
choice.index
);

if (toolCallData) {
// First, ensure the ID is not already present in toolCall
const newToolCallData = toolCallData.filter((tc) =>
toolCall.every((t) => t.id !== tc.id)
);
toolCall.push(...newToolCallData);

// Yield here, ensuring the ID and name fields are only yielded once.
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
tool_call_chunks: newToolCallData,
}),
text: "",
});
}

const chunk = new ChatGenerationChunk({
message: _convertDeltaToMessageChunk(
{
...choice.delta,
role,
} ?? {}
),
message,
text: choice.delta.content ?? "",
generationInfo: {
finishReason: choice.finish_reason,
Expand All @@ -448,6 +513,7 @@ export class ChatGroq extends BaseChatModel<
yield chunk;
void runManager?.handleLLMNewToken(chunk.text ?? "");
}

if (options.signal?.aborted) {
throw new Error("AbortError");
}
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-groq/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,5 @@ test("Groq can stream tool calls", async () => {

expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather");
expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location");
expect(finalMessage.tool_calls?.[0].id).toBeDefined();
});

0 comments on commit 203f366

Please sign in to comment.