From a614d1dbb38f13366940d107bcce3bede6d5591f Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Thu, 25 Jul 2024 15:10:20 -0700 Subject: [PATCH] groq[minor]: Implement streaming tool calls (#6203) * implemented and added test * chore: lint files * ayrn * chore: lint files * ensure name/id fields are only yielded once for streaming tool calls --- libs/langchain-groq/package.json | 4 +- libs/langchain-groq/src/chat_models.ts | 217 +++++++++++------- .../src/tests/chat_models.int.test.ts | 48 +++- .../tests/chat_models.standard.int.test.ts | 2 +- yarn.lock | 62 +---- 5 files changed, 192 insertions(+), 141 deletions(-) diff --git a/libs/langchain-groq/package.json b/libs/langchain-groq/package.json index d4ee60e3c8f4..29ab170beb16 100644 --- a/libs/langchain-groq/package.json +++ b/libs/langchain-groq/package.json @@ -35,9 +35,9 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/core": ">=0.2.16 <0.3.0", + "@langchain/core": ">=0.2.18 <0.3.0", "@langchain/openai": "~0.2.4", - "groq-sdk": "^0.3.2", + "groq-sdk": "^0.5.0", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.5" }, diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index ca7c302cfd7f..b2291dc552ce 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -8,6 +8,8 @@ import { LangSmithParams, type BaseChatModelParams, } from "@langchain/core/language_models/chat_models"; +import * as ChatCompletionsAPI from "groq-sdk/resources/chat/completions"; +import * as CompletionsAPI from "groq-sdk/resources/completions"; import { AIMessage, AIMessageChunk, @@ -19,6 +21,7 @@ import { ToolMessage, OpenAIToolCall, isAIMessage, + BaseMessageChunk, } from "@langchain/core/messages"; import { ChatGeneration, @@ -32,7 +35,6 @@ import { } from "@langchain/openai"; import { isZodSchema } from "@langchain/core/utils/types"; import Groq from "groq-sdk"; -import { ChatCompletionChunk } from "groq-sdk/lib/chat_completions_ext"; import { ChatCompletion, ChatCompletionCreateParams, @@ -146,8 +148,8 @@ export function messageToGroqRole(message: BaseMessage): GroqRoleEnum { function convertMessagesToGroqParams( messages: BaseMessage[] -): Array { - return messages.map((message): ChatCompletion.Choice.Message => { +): Array { + return messages.map((message): ChatCompletionsAPI.ChatCompletionMessage => { if (typeof message.content !== "string") { throw new Error("Non string message content not supported"); } @@ -172,12 +174,12 @@ function convertMessagesToGroqParams( completionParam.tool_call_id = (message as ToolMessage).tool_call_id; } } - return completionParam as ChatCompletion.Choice.Message; + return completionParam as ChatCompletionsAPI.ChatCompletionMessage; }); } function groqResponseToChatMessage( - message: ChatCompletion.Choice.Message + message: ChatCompletionsAPI.ChatCompletionMessage ): BaseMessage { const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as | OpenAIToolCall[] @@ -206,10 +208,34 @@ function groqResponseToChatMessage( } } +function _convertDeltaToolCallToToolCallChunk( + toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[], + index?: number +): ToolCallChunk[] | undefined { + if (!toolCalls?.length) return undefined; + + return toolCalls.map((tc) => ({ + id: tc.id, + name: tc.function?.name, + args: tc.function?.arguments, + type: "tool_call_chunk", + index, + })); +} + function _convertDeltaToMessageChunk( // eslint-disable-next-line @typescript-eslint/no-explicit-any - delta: Record -) { + delta: Record, + index: number +): { + message: BaseMessageChunk; + toolCallData?: { + id: string; + name: string; + index: number; + type: "tool_call_chunk"; + }[]; +} { const { role } = delta; const content = delta.content ?? ""; let additional_kwargs; @@ -225,13 +251,43 @@ function _convertDeltaToMessageChunk( additional_kwargs = {}; } if (role === "user") { - return new HumanMessageChunk({ content }); + return { + message: new HumanMessageChunk({ content }), + }; } else if (role === "assistant") { - return new AIMessageChunk({ content, additional_kwargs }); + const toolCallChunks = _convertDeltaToolCallToToolCallChunk( + delta.tool_calls, + index + ); + return { + message: new AIMessageChunk({ + content, + additional_kwargs, + tool_call_chunks: toolCallChunks + ? toolCallChunks.map((tc) => ({ + type: tc.type, + args: tc.args, + index: tc.index, + })) + : undefined, + }), + toolCallData: toolCallChunks + ? toolCallChunks.map((tc) => ({ + id: tc.id ?? "", + name: tc.name ?? "", + index: tc.index ?? index, + type: "tool_call_chunk", + })) + : undefined, + }; } else if (role === "system") { - return new SystemMessageChunk({ content }); + return { + message: new SystemMessageChunk({ content }), + }; } else { - return new ChatMessageChunk({ content, role }); + return { + message: new ChatMessageChunk({ content, role }), + }; } } @@ -322,8 +378,8 @@ export class ChatGroq extends BaseChatModel< ls_provider: "groq", ls_model_name: this.model, ls_model_type: "chat", - ls_temperature: params.temperature, - ls_max_tokens: params.max_tokens, + ls_temperature: params.temperature ?? this.temperature, + ls_max_tokens: params.max_tokens ?? this.maxTokens, ls_stop: options.stop, }; } @@ -331,7 +387,7 @@ export class ChatGroq extends BaseChatModel< async completionWithRetry( request: ChatCompletionCreateParamsStreaming, options?: OpenAICoreRequestOptions - ): Promise>; + ): Promise>; async completionWithRetry( request: ChatCompletionCreateParamsNonStreaming, @@ -341,7 +397,9 @@ export class ChatGroq extends BaseChatModel< async completionWithRetry( request: ChatCompletionCreateParams, options?: OpenAICoreRequestOptions - ): Promise | ChatCompletion> { + ): Promise< + AsyncIterable | ChatCompletion + > { return this.caller.call(async () => this.client.chat.completions.create(request, options) ); @@ -391,76 +449,73 @@ export class ChatGroq extends BaseChatModel< ): AsyncGenerator { const params = this.invocationParams(options); const messagesMapped = convertMessagesToGroqParams(messages); - if (options.tools !== undefined && options.tools.length > 0) { - const result = await this._generateNonStreaming( - messages, - options, - runManager - ); - const generationMessage = result.generations[0].message as AIMessage; - if ( - generationMessage === undefined || - typeof generationMessage.content !== "string" - ) { - throw new Error("Could not parse Groq output."); + const response = await this.completionWithRetry( + { + ...params, + messages: messagesMapped, + stream: true, + }, + { + signal: options?.signal, + headers: options?.headers, } - const toolCallChunks: ToolCallChunk[] | undefined = - generationMessage.tool_calls?.map((toolCall, i) => ({ - name: toolCall.name, - args: JSON.stringify(toolCall.args), - id: toolCall.id, - index: i, - type: "tool_call_chunk", - })); - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: generationMessage.content, - additional_kwargs: generationMessage.additional_kwargs, - tool_call_chunks: toolCallChunks, - }), - text: generationMessage.content, - }); - } else { - const response = await this.completionWithRetry( - { - ...params, - messages: messagesMapped, - stream: true, - }, + ); + let role = ""; + const toolCall: { + id: string; + name: string; + index: number; + type: "tool_call_chunk"; + }[] = []; + for await (const data of response) { + const choice = data?.choices[0]; + if (!choice) { + continue; + } + // The `role` field is populated in the first delta of the response + // but is not present in subsequent deltas. Extract it when available. + if (choice.delta?.role) { + role = choice.delta.role; + } + + const { message, toolCallData } = _convertDeltaToMessageChunk( { - signal: options?.signal, - headers: options?.headers, - } + ...choice.delta, + role, + } ?? {}, + choice.index ); - let role = ""; - for await (const data of response) { - const choice = data?.choices[0]; - if (!choice) { - continue; - } - // The `role` field is populated in the first delta of the response - // but is not present in subsequent deltas. Extract it when available. - if (choice.delta?.role) { - role = choice.delta.role; - } - const chunk = new ChatGenerationChunk({ - message: _convertDeltaToMessageChunk( - { - ...choice.delta, - role, - } ?? {} - ), - text: choice.delta.content ?? "", - generationInfo: { - finishReason: choice.finish_reason, - }, + + if (toolCallData) { + // First, ensure the ID is not already present in toolCall + const newToolCallData = toolCallData.filter((tc) => + toolCall.every((t) => t.id !== tc.id) + ); + toolCall.push(...newToolCallData); + + // Yield here, ensuring the ID and name fields are only yielded once. + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + tool_call_chunks: newToolCallData, + }), + text: "", }); - yield chunk; - void runManager?.handleLLMNewToken(chunk.text ?? ""); - } - if (options.signal?.aborted) { - throw new Error("AbortError"); } + + const chunk = new ChatGenerationChunk({ + message, + text: choice.delta.content ?? "", + generationInfo: { + finishReason: choice.finish_reason, + }, + }); + yield chunk; + void runManager?.handleLLMNewToken(chunk.text ?? ""); + } + + if (options.signal?.aborted) { + throw new Error("AbortError"); } } @@ -518,7 +573,7 @@ export class ChatGroq extends BaseChatModel< completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, - } = data.usage as ChatCompletion.Usage; + } = data.usage as CompletionsAPI.CompletionUsage; if (completionTokens) { tokenUsage.completionTokens = diff --git a/libs/langchain-groq/src/tests/chat_models.int.test.ts b/libs/langchain-groq/src/tests/chat_models.int.test.ts index 20b6d482356b..07817b8a800d 100644 --- a/libs/langchain-groq/src/tests/chat_models.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.int.test.ts @@ -1,5 +1,13 @@ import { test } from "@jest/globals"; -import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; +import { + AIMessage, + AIMessageChunk, + HumanMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { tool } from "@langchain/core/tools"; +import { z } from "zod"; +import { concat } from "@langchain/core/utils/stream"; import { ChatGroq } from "../chat_models.js"; test("invoke", async () => { @@ -197,3 +205,41 @@ test("Few shotting with tool calls", async () => { // console.log(res); expect(res.content).toContain("24"); }); + +test("Groq can stream tool calls", async () => { + const model = new ChatGroq({ + model: "llama-3.1-70b-versatile", + temperature: 0, + }); + + const weatherTool = tool((_) => "The temperature is 24 degrees with hail.", { + name: "get_current_weather", + schema: z.object({ + location: z + .string() + .describe("The location to get the current weather for."), + }), + description: "Get the current weather in a given location.", + }); + + const modelWithTools = model.bindTools([weatherTool]); + + const stream = await modelWithTools.stream( + "What is the weather in San Francisco?" + ); + + let finalMessage: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalMessage = !finalMessage ? chunk : concat(finalMessage, chunk); + } + + expect(finalMessage).toBeDefined(); + if (!finalMessage) return; + + expect(finalMessage.tool_calls?.[0]).toBeDefined(); + if (!finalMessage.tool_calls?.[0]) return; + + expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather"); + expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location"); + expect(finalMessage.tool_calls?.[0].id).toBeDefined(); +}); diff --git a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts index 9e1a2774771f..82c4e3c392f8 100644 --- a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts @@ -19,7 +19,7 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests< chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, constructorArgs: { - model: "mixtral-8x7b-32768", + model: "llama-3.1-70b-versatile", }, }); } diff --git a/yarn.lock b/yarn.lock index 7a03517cb874..58d96d15b878 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11903,7 +11903,7 @@ __metadata: resolution: "@langchain/groq@workspace:libs/langchain-groq" dependencies: "@jest/globals": ^29.5.0 - "@langchain/core": ">=0.2.16 <0.3.0" + "@langchain/core": ">=0.2.18 <0.3.0" "@langchain/openai": "workspace:^" "@langchain/scripts": ~0.0.20 "@langchain/standard-tests": 0.0.0 @@ -11921,7 +11921,7 @@ __metadata: eslint-plugin-import: ^2.27.5 eslint-plugin-no-instanceof: ^1.0.1 eslint-plugin-prettier: ^4.2.1 - groq-sdk: ^0.3.2 + groq-sdk: ^0.5.0 jest: ^29.5.0 jest-environment-node: ^29.6.4 prettier: ^2.8.3 @@ -20930,13 +20930,6 @@ __metadata: languageName: node linkType: hard -"base-64@npm:^0.1.0": - version: 0.1.0 - resolution: "base-64@npm:0.1.0" - checksum: 5a42938f82372ab5392cbacc85a5a78115cbbd9dbef9f7540fa47d78763a3a8bd7d598475f0d92341f66285afd377509851a9bb5c67bbecb89686e9255d5b3eb - languageName: node - linkType: hard - "base-64@npm:^1.0.0": version: 1.0.0 resolution: "base-64@npm:1.0.0" @@ -21746,13 +21739,6 @@ __metadata: languageName: node linkType: hard -"charenc@npm:0.0.2": - version: 0.0.2 - resolution: "charenc@npm:0.0.2" - checksum: 81dcadbe57e861d527faf6dd3855dc857395a1c4d6781f4847288ab23cffb7b3ee80d57c15bba7252ffe3e5e8019db767757ee7975663ad2ca0939bb8fcaf2e5 - languageName: node - linkType: hard - "cheerio-select@npm:^2.1.0": version: 2.1.0 resolution: "cheerio-select@npm:2.1.0" @@ -22977,13 +22963,6 @@ __metadata: languageName: node linkType: hard -"crypt@npm:0.0.2": - version: 0.0.2 - resolution: "crypt@npm:0.0.2" - checksum: baf4c7bbe05df656ec230018af8cf7dbe8c14b36b98726939cef008d473f6fe7a4fad906cfea4062c93af516f1550a3f43ceb4d6615329612c6511378ed9fe34 - languageName: node - linkType: hard - "crypto-js@npm:^4.2.0": version: 4.2.0 resolution: "crypto-js@npm:4.2.0" @@ -24173,16 +24152,6 @@ __metadata: languageName: node linkType: hard -"digest-fetch@npm:^1.3.0": - version: 1.3.0 - resolution: "digest-fetch@npm:1.3.0" - dependencies: - base-64: ^0.1.0 - md5: ^2.3.0 - checksum: 8ebdb4b9ef02b1ac0da532d25c7d08388f2552813dfadabfe7c4630e944bb4a48093b997fc926440a10e1ccf4912f2ce9adcf2d6687b0518dab8480e08f22f9d - languageName: node - linkType: hard - "dingbat-to-unicode@npm:^1.0.1": version: 1.0.1 resolution: "dingbat-to-unicode@npm:1.0.1" @@ -28060,20 +28029,19 @@ __metadata: languageName: node linkType: hard -"groq-sdk@npm:^0.3.2": - version: 0.3.2 - resolution: "groq-sdk@npm:0.3.2" +"groq-sdk@npm:^0.5.0": + version: 0.5.0 + resolution: "groq-sdk@npm:0.5.0" dependencies: "@types/node": ^18.11.18 "@types/node-fetch": ^2.6.4 abort-controller: ^3.0.0 agentkeepalive: ^4.2.1 - digest-fetch: ^1.3.0 form-data-encoder: 1.7.2 formdata-node: ^4.3.2 node-fetch: ^2.6.7 web-streams-polyfill: ^3.2.1 - checksum: 78cdc02ac8e87d5c47c2857def55d14249ee1b698f11d06db01a86227716a3e4e2312224996168f7edee51992862082dd4dfcdfec54b765d698855db9971e525 + checksum: 051ca56e99e4a2440080943c831b109687dd346b24155d3f085113df1ad0639cb95724c14a05611f7314d340db8bf342af425eb11905c97bc6a6948cd7262f04 languageName: node linkType: hard @@ -29143,13 +29111,6 @@ __metadata: languageName: node linkType: hard -"is-buffer@npm:~1.1.6": - version: 1.1.6 - resolution: "is-buffer@npm:1.1.6" - checksum: 4a186d995d8bbf9153b4bd9ff9fd04ae75068fe695d29025d25e592d9488911eeece84eefbd8fa41b8ddcc0711058a71d4c466dcf6f1f6e1d83830052d8ca707 - languageName: node - linkType: hard - "is-callable@npm:^1.1.3, is-callable@npm:^1.1.4, is-callable@npm:^1.2.7": version: 1.2.7 resolution: "is-callable@npm:1.2.7" @@ -32500,17 +32461,6 @@ __metadata: languageName: node linkType: hard -"md5@npm:^2.3.0": - version: 2.3.0 - resolution: "md5@npm:2.3.0" - dependencies: - charenc: 0.0.2 - crypt: 0.0.2 - is-buffer: ~1.1.6 - checksum: a63cacf4018dc9dee08c36e6f924a64ced735b37826116c905717c41cebeb41a522f7a526ba6ad578f9c80f02cb365033ccd67fe186ffbcc1a1faeb75daa9b6e - languageName: node - linkType: hard - "mdast-squeeze-paragraphs@npm:^4.0.0": version: 4.0.0 resolution: "mdast-squeeze-paragraphs@npm:4.0.0"