diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 432620dcc52b..f10fd07ea913 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -2,7 +2,7 @@ import { v4 as uuidv4 } from "uuid"; import { AIMessage, AIMessageChunk, - AIMessageFields, + AIMessageChunkFields, BaseMessage, BaseMessageChunk, BaseMessageFields, @@ -566,7 +566,7 @@ export function chunkToString(chunk: BaseMessageChunk): string { } export function partToMessageChunk(part: GeminiPart): BaseMessageChunk { - const fields = partsToBaseMessageFields([part]); + const fields = partsToBaseMessageChunkFields([part]); if (typeof fields.content === "string") { return new AIMessageChunk(fields); } else if (fields.content.every((item) => item.type === "text")) { @@ -636,12 +636,13 @@ export function responseToBaseMessageFields( response: GoogleLLMResponse ): BaseMessageFields { const parts = responseToParts(response); - return partsToBaseMessageFields(parts); + return partsToBaseMessageChunkFields(parts); } -export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { - const fields: AIMessageFields = { +export function partsToBaseMessageChunkFields(parts: GeminiPart[]): AIMessageChunkFields { + const fields: AIMessageChunkFields = { content: partsToMessageContent(parts), + tool_call_chunks: [], tool_calls: [], invalid_tool_calls: [], }; @@ -650,6 +651,13 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { if (rawTools.length > 0) { const tools = toolsRawToTools(rawTools); for (const tool of tools) { + fields.tool_call_chunks?.push({ + name: tool.function.name, + args: tool.function.arguments, + id: tool.id, + type: "tool_call_chunk", + }); + try { fields.tool_calls?.push({ name: tool.function.name, @@ -661,7 +669,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { } catch (e: any) { fields.invalid_tool_calls?.push({ name: tool.function.name, - args: JSON.parse(tool.function.arguments), + args: tool.function.arguments, id: tool.id, error: e.message, type: "invalid_tool_call", @@ -675,6 +683,42 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { return fields; } +// export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { +// const fields: AIMessageFields = { +// content: partsToMessageContent(parts), +// tool_calls: [], +// invalid_tool_calls: [], +// }; + +// const rawTools = partsToToolsRaw(parts); +// if (rawTools.length > 0) { +// const tools = toolsRawToTools(rawTools); +// for (const tool of tools) { +// try { +// fields.tool_calls?.push({ +// name: tool.function.name, +// args: JSON.parse(tool.function.arguments), +// id: tool.id, +// type: "tool_call", +// }); +// // eslint-disable-next-line @typescript-eslint/no-explicit-any +// } catch (e: any) { +// fields.invalid_tool_calls?.push({ +// name: tool.function.name, +// args: JSON.parse(tool.function.arguments), +// id: tool.id, +// error: e.message, +// type: "invalid_tool_call", +// }); +// } +// } +// fields.additional_kwargs = { +// tool_calls: tools, +// }; +// } +// return fields; +// } + export function responseToBaseMessage( response: GoogleLLMResponse ): BaseMessage { diff --git a/libs/langchain-google-vertexai/package.json b/libs/langchain-google-vertexai/package.json index 34a52b83c702..867234a91548 100644 --- a/libs/langchain-google-vertexai/package.json +++ b/libs/langchain-google-vertexai/package.json @@ -70,7 +70,8 @@ "release-it": "^15.10.1", "rollup": "^4.5.2", "ts-jest": "^29.1.0", - "typescript": "<5.2.0" + "typescript": "<5.2.0", + "zod": "^3.22.4" }, "publishConfig": { "access": "public" diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index 054462d7d1c0..3af47cb2f06b 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -13,6 +13,9 @@ import { } from "@langchain/core/messages"; import { ChatVertexAI } from "../chat_models.js"; import { GeminiTool } from "../types.js"; +import { tool } from "@langchain/core/tools"; +import { concat } from "@langchain/core/utils/stream"; +import { z } from "zod"; describe("GAuth Chat", () => { test("invoke", async () => { @@ -41,7 +44,7 @@ describe("GAuth Chat", () => { expect(textContent.text).toEqual("2"); */ } catch (e) { - console.error(e); + // console.error(e); throw e; } }); @@ -81,7 +84,7 @@ describe("GAuth Chat", () => { expect(["H", "T"]).toContainEqual(textContent.text); */ } catch (e) { - console.error(e); + // console.error(e); throw e; } }); @@ -108,12 +111,11 @@ describe("GAuth Chat", () => { const lastChunk = resArray[resArray.length - 1]; expect(lastChunk).toBeDefined(); expect(lastChunk._getType()).toEqual("ai"); - const aiChunk = lastChunk as AIMessageChunk; - console.log(aiChunk); - - console.log(JSON.stringify(resArray, null, 2)); + // const aiChunk = lastChunk as AIMessageChunk; + // console.log(aiChunk); + // console.log(JSON.stringify(resArray, null, 2)); } catch (e) { - console.error(e); + // console.error(e); throw e; } }); @@ -209,7 +211,7 @@ describe("GAuth Chat", () => { for await (const chunk of res) { resArray.push(chunk); } - console.log(JSON.stringify(resArray, null, 2)); + // console.log(JSON.stringify(resArray, null, 2)); }); test("withStructuredOutput", async () => { @@ -249,7 +251,7 @@ test("Stream token count usage_metadata", async () => { res = res.concat(chunk); } } - console.log(res); + // console.log(res); expect(res?.usage_metadata).toBeDefined(); if (!res?.usage_metadata) { return; @@ -276,7 +278,7 @@ test("streamUsage excludes token usage", async () => { res = res.concat(chunk); } } - console.log(res); + // console.log(res); expect(res?.usage_metadata).not.toBeDefined(); }); @@ -286,7 +288,7 @@ test("Invoke token count usage_metadata", async () => { maxOutputTokens: 10, }); const res = await model.invoke("Why is the sky blue? Be concise."); - console.log(res); + // console.log(res); expect(res?.usage_metadata).toBeDefined(); if (!res?.usage_metadata) { return; @@ -322,3 +324,36 @@ test("Streaming true constructor param will stream", async () => { expect(totalTokenCount).toBeGreaterThan(1); }); + +test("ChatGoogleGenerativeAI can stream tools", async () => { + const model = new ChatVertexAI({}); + + const weatherTool = tool((_) => { + return "The weather in San Francisco today is 18 degrees and sunny." + }, { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z.string().describe("The location to get the weather for."), + }) + }) + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream("Whats the weather like today in San Francisco?"); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; + + const toolCalls = finalChunk.tool_calls; + expect(toolCalls).toBeDefined(); + if (!toolCalls) { + throw new Error("tool_calls not in response"); + } + expect(toolCalls.length).toBe(1); + expect(toolCalls[0].name).toBe("current_weather_tool"); + expect(toolCalls[0].args).toHaveProperty("location") +}); \ No newline at end of file diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts index c44f36916ddc..60c5b6c421b0 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests< Cls: ChatVertexAI, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + invokeResponseType: AIMessageChunk, constructorArgs: { model: "gemini-1.5-pro", }, @@ -32,6 +33,14 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests< "Not implemented." ); } + + async testInvokeMoreComplexTools() { + this.skipTestMessage( + "testInvokeMoreComplexTools", + "ChatVertexAI", + "Google VertexAI does not support tool schemas where the object properties are not defined." + ); + } } const testClass = new ChatVertexAIStandardIntegrationTests(); diff --git a/yarn.lock b/yarn.lock index c5dcede32a3e..94223547d02b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11695,6 +11695,7 @@ __metadata: rollup: ^4.5.2 ts-jest: ^29.1.0 typescript: <5.2.0 + zod: ^3.22.4 languageName: unknown linkType: soft