From 162d7ea6189d5e90219963b639071178d5ad5e26 Mon Sep 17 00:00:00 2001 From: afirstenberg Date: Tue, 17 Dec 2024 18:35:00 -0500 Subject: [PATCH 1/8] Test over multiple models. Possible tweaks for functions under gemini-2.0-flash-exp --- .../src/tests/chat_models.int.test.ts | 70 +++++++++++++++---- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index ddcdf579a394..74e743765c9d 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -59,18 +59,43 @@ const calculatorTool = tool((_) => "no-op", { }), }); -describe("GAuth Gemini Chat", () => { +/* + * Which models do we want to run the test suite against? + */ +const testGeminiModelNames = [ + ["gemini-1.5-pro-002"], + ["gemini-1.5-flash-002"], + ["gemini-2.0-flash-exp"], +] + +/* + * Some models may have usage quotas still. + * For those models, set how long (in millis) to wait in between each test. + */ +const testGeminiModelDelay: Record = { + "gemini-2.0-flash-exp": 5000, +} + +describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { let recorder: GoogleRequestRecorder; let callbacks: BaseCallbackHandler[]; - beforeEach(() => { + beforeEach(async () => { recorder = new GoogleRequestRecorder(); callbacks = [recorder, new GoogleRequestLogger()]; + + const delay = testGeminiModelDelay[modelName] ?? 0; + if (delay) { + console.log(`Delaying for ${delay}ms`) + // eslint-disable-next-line no-promise-executor-return + await new Promise(resolve => setTimeout(resolve,delay)); + } }); test("invoke", async () => { const model = new ChatVertexAI({ callbacks, + modelName, }); const res = await model.invoke("What is 1 + 1?"); expect(res).toBeDefined(); @@ -84,8 +109,10 @@ describe("GAuth Gemini Chat", () => { expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); }); - test("generate", async () => { - const model = new ChatVertexAI(); + test(`generate`, async () => { + const model = new ChatVertexAI({ + modelName, + }); const messages: BaseMessage[] = [ new SystemMessage( "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." @@ -103,12 +130,13 @@ describe("GAuth Gemini Chat", () => { expect(typeof aiMessage.content).toBe("string"); const text = aiMessage.content as string; - expect(["H", "T"]).toContainEqual(text); + expect(["H", "T"]).toContainEqual(text.trim()); }); test("stream", async () => { const model = new ChatVertexAI({ callbacks, + modelName, }); const input: BaseLanguageModelInput = new ChatPromptValue([ new SystemMessage( @@ -153,7 +181,11 @@ describe("GAuth Gemini Chat", () => { ], }, ]; - const model = new ChatVertexAI().bind({ tools }); + const model = new ChatVertexAI({ + modelName, + }).bind({ + tools, + }); const result = await model.invoke("Run a test on the cobalt project"); expect(result).toHaveProperty("content"); expect(result.content).toBe(""); @@ -197,7 +229,11 @@ describe("GAuth Gemini Chat", () => { ], }, ]; - const model = new ChatVertexAI().bind({ tools }); + const model = new ChatVertexAI({ + modelName, + }).bind({ + tools, + }); const toolResult = { testPassed: true, }; @@ -241,7 +277,9 @@ describe("GAuth Gemini Chat", () => { required: ["location"], }, }; - const model = new ChatVertexAI().withStructuredOutput(tool); + const model = new ChatVertexAI({ + modelName, + }).withStructuredOutput(tool); const result = await model.invoke("What is the weather in Paris?"); expect(result).toHaveProperty("location"); }); @@ -275,7 +313,7 @@ describe("GAuth Gemini Chat", () => { resolvers: [resolver], }); const model = new ChatGoogle({ - modelName: "gemini-1.5-flash", + modelName, apiConfig: { mediaManager, }, @@ -320,6 +358,7 @@ describe("GAuth Gemini Chat", () => { const model = new ChatVertexAI({ temperature: 0, maxOutputTokens: 10, + modelName, }); let res: AIMessageChunk | null = null; for await (const chunk of await model.stream( @@ -347,6 +386,7 @@ describe("GAuth Gemini Chat", () => { const model = new ChatVertexAI({ temperature: 0, streamUsage: false, + modelName, }); let res: AIMessageChunk | null = null; for await (const chunk of await model.stream( @@ -366,6 +406,7 @@ describe("GAuth Gemini Chat", () => { const model = new ChatVertexAI({ temperature: 0, maxOutputTokens: 10, + modelName, }); const res = await model.invoke("Why is the sky blue? Be concise."); // console.log(res); @@ -384,6 +425,7 @@ describe("GAuth Gemini Chat", () => { const modelWithStreaming = new ChatVertexAI({ maxOutputTokens: 50, streaming: true, + modelName, }); let totalTokenCount = 0; @@ -407,7 +449,7 @@ describe("GAuth Gemini Chat", () => { test("Can force a model to invoke a tool", async () => { const model = new ChatVertexAI({ - model: "gemini-1.5-pro", + modelName, }); const modelWithTools = model.bind({ tools: [calculatorTool, weatherTool], @@ -425,8 +467,10 @@ describe("GAuth Gemini Chat", () => { expect(result.tool_calls?.[0].args).toHaveProperty("expression"); }); - test("ChatGoogleGenerativeAI can stream tools", async () => { - const model = new ChatVertexAI({}); + test(`stream tools`, async () => { + const model = new ChatVertexAI({ + modelName, + }); const weatherTool = tool( (_) => "The weather in San Francisco today is 18 degrees and sunny.", @@ -474,7 +518,7 @@ describe("GAuth Gemini Chat", () => { const audioMimeType = "audio/wav"; const model = new ChatVertexAI({ - model: "gemini-1.5-flash", + model: modelName, temperature: 0, maxRetries: 0, }); From 14b3df1887bbb4c124b4cd07c03625792adbde20 Mon Sep 17 00:00:00 2001 From: afirstenberg Date: Tue, 24 Dec 2024 18:46:37 -0500 Subject: [PATCH 2/8] Test multiple models against all the tests. Add Gemini 2.0 googleSearch support. Add configuration to determine if the search request should be modified based on the model. --- .../src/chat_models.ts | 25 +++- libs/langchain-google-common/src/types.ts | 41 ++++++- .../src/utils/common.ts | 20 +++- .../src/utils/gemini.ts | 38 +++++- .../src/tests/chat_models.int.test.ts | 110 ++++++++++-------- 5 files changed, 177 insertions(+), 57 deletions(-) diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index 7d476b94cf2c..f5737dbd81c4 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -33,6 +33,7 @@ import { GoogleAIBaseLanguageModelCallOptions, GoogleAIAPI, GoogleAIAPIParams, + GoogleSearchToolSetting, } from "./types.js"; import { convertToGeminiTools, @@ -97,10 +98,32 @@ export class ChatConnection extends AbstractGoogleLLMConnection< return true; } + computeGoogleSearchToolAdjustmentFromModel(): Exclude { + if (this.modelName.startsWith("gemini-1.0")) { + return "googleSearchRetrieval"; + } else if (this.modelName.startsWith("gemini-1.5")) { + return "googleSearchRetrieval"; + } else { + return "googleSearch"; + } + } + + computeGoogleSearchToolAdjustment(apiConfig: GeminiAPIConfig): Exclude { + const adj = apiConfig.googleSearchToolAdjustment; + if (adj === undefined || adj === true) { + return this.computeGoogleSearchToolAdjustmentFromModel(); + } else { + return adj; + } + } + buildGeminiAPI(): GoogleAIAPI { + const apiConfig: GeminiAPIConfig = this.apiConfig as GeminiAPIConfig ?? {}; + const googleSearchToolAdjustment = this.computeGoogleSearchToolAdjustment(apiConfig); const geminiConfig: GeminiAPIConfig = { useSystemInstruction: this.useSystemInstruction, - ...(this.apiConfig as GeminiAPIConfig), + googleSearchToolAdjustment, + ...apiConfig, }; return getGeminiAPI(geminiConfig); } diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index b88b3e01d090..dfbcb0b334ef 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -307,12 +307,37 @@ export interface GeminiContent { role: GeminiRole; // Vertex AI requires the role } +/* + * If additional attributes are added here, they should also be + * added to the attributes below + */ export interface GeminiTool { functionDeclarations?: GeminiFunctionDeclaration[]; - googleSearchRetrieval?: GoogleSearchRetrieval; + googleSearchRetrieval?: GoogleSearchRetrieval; // Gemini-1.5 + googleSearch?: GoogleSearch; // Gemini-2.0 retrieval?: VertexAIRetrieval; } +/* + * The known strings in this type should match those in GeminiSearchToolAttribuets + */ +export type GoogleSearchToolSetting = + | boolean + | "googleSearchRetrieval" + | "googleSearch" + | string; + +export const GeminiSearchToolAttributes = [ + "googleSearchRetrieval", + "googleSearch", +] + +export const GeminiToolAttributes = [ + "functionDeclaration", + "retrieval", + ...GeminiSearchToolAttributes, +] + export interface GoogleSearchRetrieval { dynamicRetrievalConfig?: { mode?: string; @@ -320,6 +345,8 @@ export interface GoogleSearchRetrieval { }; } +export interface GoogleSearch {} + export interface VertexAIRetrieval { vertexAiSearch: { datastore: string; @@ -467,6 +494,18 @@ export interface GeminiAPIConfig { safetyHandler?: GoogleAISafetyHandler; mediaManager?: MediaManager; useSystemInstruction?: boolean; + + /** + * How to handle the Google Search tool, since the name (and format) + * of the tool changes between Gemini 1.5 and Gemini 2.0. + * true - Change based on the model version. (Default) + * false - Do not change the tool name provided + * string value - Use this as the attribute name for the search + * tool, adapting any tool attributes if possible. + * When the model is created, a "true" or default setting + * will be changed to a string based on the model. + */ + googleSearchToolAdjustment?: GoogleSearchToolSetting; } export type GoogleAIAPIConfig = GeminiAPIConfig | AnthropicAPIConfig; diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts index b40ce25fe3fc..ea941082e2db 100644 --- a/libs/langchain-google-common/src/utils/common.ts +++ b/libs/langchain-google-common/src/utils/common.ts @@ -1,10 +1,11 @@ import { isOpenAITool } from "@langchain/core/language_models/base"; import { isLangChainTool } from "@langchain/core/utils/function_calling"; import { isModelGemini, validateGeminiParams } from "./gemini.js"; -import type { +import { GeminiFunctionDeclaration, GeminiFunctionSchema, GeminiTool, + GeminiToolAttributes, GoogleAIBaseLanguageModelCallOptions, GoogleAIModelParams, GoogleAIModelRequestParams, @@ -61,12 +62,25 @@ function processToolChoice( throw new Error("Object inputs for tool_choice not supported."); } +function isGeminiTool(tool: GoogleAIToolType): boolean { + for (const toolAttribute of GeminiToolAttributes) { + if (toolAttribute in tool) { + return true; + } + } + return false; +} + +function isGeminiNonFunctionTool(tool: GoogleAIToolType): boolean { + return isGeminiTool(tool) && !("functionDeclaration" in tool); +} + export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] { const geminiTools: GeminiTool[] = []; let functionDeclarationsIndex = -1; tools.forEach((tool) => { - if ("googleSearchRetrieval" in tool || "retrieval" in tool) { - geminiTools.push(tool); + if (isGeminiNonFunctionTool(tool)) { + geminiTools.push(tool as GeminiTool); } else { if (functionDeclarationsIndex === -1) { geminiTools.push({ diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 23c46f3783db..8e61602ebc47 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -48,6 +48,7 @@ import { GeminiTool, GoogleAIModelRequestParams, GoogleAIToolType, + GeminiSearchToolAttributes, } from "../types.js"; import { zodToGeminiParameters } from "./zod_to_gemini_parameters.js"; @@ -1015,17 +1016,44 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { }; } + function searchToolName(tool: GeminiTool): string | undefined { + for (const name of GeminiSearchToolAttributes) { + if (name in tool) { + return name + } + } + return undefined; + } + + function cleanGeminiTool(tool: GeminiTool): GeminiTool { + const orig = searchToolName(tool); + const adj = config?.googleSearchToolAdjustment; + if (orig && adj && adj !== orig) { + return { + [adj as string]: {}, + } + } else { + return tool; + } + } + function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] { const tools: GoogleAIToolType[] | undefined = parameters?.tools; if (!tools || tools.length === 0) { return []; } - // Group all LangChain tools into a single functionDeclarations array - const langChainTools = tools.filter(isLangChainTool); - const otherTools = tools.filter( - (tool) => !isLangChainTool(tool) - ) as GeminiTool[]; + // Group all LangChain tools into a single functionDeclarations array. + // Gemini Tools may be normalized to different tool names + const langChainTools: StructuredToolParams[] = []; + const otherTools: GeminiTool[] = []; + tools.forEach(tool => { + if (isLangChainTool(tool)) { + langChainTools.push(tool); + } else { + otherTools.push(cleanGeminiTool(tool as GeminiTool)); + } + }) const result: GeminiTool[] = [...otherTools]; diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index 74e743765c9d..49faa348a82f 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -66,6 +66,7 @@ const testGeminiModelNames = [ ["gemini-1.5-pro-002"], ["gemini-1.5-flash-002"], ["gemini-2.0-flash-exp"], + // ["gemini-2.0-flash-thinking-exp-1219"], ] /* @@ -74,6 +75,7 @@ const testGeminiModelNames = [ */ const testGeminiModelDelay: Record = { "gemini-2.0-flash-exp": 5000, + "gemini-2.0-flash-thinking-exp-1219": 5000, } describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { @@ -549,6 +551,67 @@ describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { expect(typeof response.content).toBe("string"); expect((response.content as string).length).toBeGreaterThan(15); }); + + test("Supports GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatVertexAI({ + modelName, + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Supports GoogleSearchTool", async () => { + const searchTool: GeminiTool = { + googleSearch: { + }, + }; + const model = new ChatVertexAI({ + modelName, + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Can stream GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatVertexAI({ + modelName, + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const stream = await model.stream("Who won the 2024 MLB World Series?"); + let finalMsg: AIMessageChunk | undefined; + for await (const msg of stream) { + finalMsg = finalMsg ? concat(finalMsg, msg) : msg; + } + if (!finalMsg) { + throw new Error("finalMsg is undefined"); + } + expect(finalMsg.content as string).toContain("Dodgers"); + }); + }); describe("GAuth Anthropic Chat", () => { @@ -661,50 +724,3 @@ describe("GAuth Anthropic Chat", () => { expect(toolCalls?.[0].args).toHaveProperty("location"); }); }); - -describe("GoogleSearchRetrievalTool", () => { - test("Supports GoogleSearchRetrievalTool", async () => { - const searchRetrievalTool = { - googleSearchRetrieval: { - dynamicRetrievalConfig: { - mode: "MODE_DYNAMIC", - dynamicThreshold: 0.7, // default is 0.7 - }, - }, - }; - const model = new ChatVertexAI({ - model: "gemini-1.5-pro", - temperature: 0, - maxRetries: 0, - }).bindTools([searchRetrievalTool]); - - const result = await model.invoke("Who won the 2024 MLB World Series?"); - expect(result.content as string).toContain("Dodgers"); - }); - - test("Can stream GoogleSearchRetrievalTool", async () => { - const searchRetrievalTool = { - googleSearchRetrieval: { - dynamicRetrievalConfig: { - mode: "MODE_DYNAMIC", - dynamicThreshold: 0.7, // default is 0.7 - }, - }, - }; - const model = new ChatVertexAI({ - model: "gemini-1.5-pro", - temperature: 0, - maxRetries: 0, - }).bindTools([searchRetrievalTool]); - - const stream = await model.stream("Who won the 2024 MLB World Series?"); - let finalMsg: AIMessageChunk | undefined; - for await (const msg of stream) { - finalMsg = finalMsg ? concat(finalMsg, msg) : msg; - } - if (!finalMsg) { - throw new Error("finalMsg is undefined"); - } - expect(finalMsg.content as string).toContain("Dodgers"); - }); -}); From 90f12a2429d0c93936a83b8caaa5fca69185fa29 Mon Sep 17 00:00:00 2001 From: afirstenberg Date: Tue, 24 Dec 2024 22:31:27 -0500 Subject: [PATCH 3/8] Only try to use API Key if we're not on Google Cloud Platform / Vertex AI. Addresses #7399 and possibly other issues. --- libs/langchain-google-common/src/chat_models.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index f5737dbd81c4..a9cfbcb951ad 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -231,7 +231,12 @@ export abstract class ChatGoogleBase } buildApiKey(fields?: GoogleAIBaseLLMInput): string | undefined { - return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); + if (fields?.platformType !== "gcp") { + return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); + } else { + // GCP doesn't support API Keys + return undefined; + } } buildClient( From 406ee34f8480d5b848e2eb6a6e339db200618556 Mon Sep 17 00:00:00 2001 From: afirstenberg Date: Tue, 24 Dec 2024 22:33:03 -0500 Subject: [PATCH 4/8] Test multiple models against all the tests for both GCP and GAI platforms. (There are some known, and expected, failures.) --- .../src/tests/chat_models.int.test.ts | 582 +++++++++++++++++- 1 file changed, 578 insertions(+), 4 deletions(-) diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts index 0e10359599b3..b406d61098a3 100644 --- a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -1,12 +1,12 @@ /* eslint-disable import/no-extraneous-dependencies */ -import { StructuredTool } from "@langchain/core/tools"; +import { StructuredTool, tool } from "@langchain/core/tools"; import { z } from "zod"; -import { test } from "@jest/globals"; +import { expect, test } from "@jest/globals"; import { AIMessage, AIMessageChunk, BaseMessage, - BaseMessageChunk, + BaseMessageChunk, BaseMessageLike, HumanMessage, HumanMessageChunk, MessageContentComplex, @@ -16,10 +16,21 @@ import { import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; import { ChatPromptValue } from "@langchain/core/prompt_values"; import { + BackedBlobStore, + MediaBlob, MediaManager, + ReadThroughBlobStore, SimpleWebBlobStore, } from "@langchain/google-common/experimental/utils/media_core"; -import { ChatGoogle } from "../chat_models.js"; +import {GeminiTool, GooglePlatformType, GoogleRequestLogger, GoogleRequestRecorder} from "@langchain/google-common"; +import {BaseCallbackHandler} from "@langchain/core/callbacks/base"; +import {InMemoryStore} from "@langchain/core/stores"; +import {BlobStoreGoogleCloudStorage} from "@langchain/google-gauth"; +import {GoogleCloudStorageUri} from "@langchain/google-common/experimental/media"; +import {concat} from "@langchain/core/utils/stream"; +import fs from "fs/promises"; +import {ChatPromptTemplate, MessagesPlaceholder} from "@langchain/core/prompts"; +import {ChatGoogle, ChatGoogleInput} from "../chat_models.js"; import { BlobStoreAIStudioFile } from "../media.js"; class WeatherTool extends StructuredTool { @@ -247,3 +258,566 @@ describe("Google APIKey Chat", () => { } }); }); + +const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + schema: z.object({ + location: z.string().describe("The name of city to get the weather for."), + }), +}); + +const calculatorTool = tool((_) => "no-op", { + name: "calculator", + description: "Calculate the result of a math expression.", + schema: z.object({ + expression: z.string().describe("The math expression to calculate."), + }), +}); + +/* + * Which models do we want to run the test suite against + * and on which platforms? + */ +const testGeminiModelNames = [ + {modelName: "gemini-1.5-pro-002", platformType: "gai", apiVersion: "v1beta"}, + {modelName: "gemini-1.5-pro-002", platformType: "gcp", apiVersion: "v1"}, + {modelName: "gemini-1.5-flash-002", platformType: "gai", apiVersion: "v1beta"}, + {modelName: "gemini-1.5-flash-002", platformType: "gcp", apiVersion: "v1"}, + {modelName: "gemini-2.0-flash-exp", platformType: "gai", apiVersion: "v1beta"}, + {modelName: "gemini-2.0-flash-exp", platformType: "gcp", apiVersion: "v1"}, + // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gai"}, + // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gcp"}, +] + +/* + * Some models may have usage quotas still. + * For those models, set how long (in millis) to wait in between each test. + */ +const testGeminiModelDelay: Record = { + "gemini-2.0-flash-exp": 10000, + "gemini-2.0-flash-thinking-exp-1219": 10000, +} + +describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($modelName)", ({modelName, platformType, apiVersion}) => { + let recorder: GoogleRequestRecorder; + let callbacks: BaseCallbackHandler[]; + + function newChatGoogle(fields?: ChatGoogleInput): ChatGoogle { + return new ChatGoogle({ + modelName, + platformType: platformType as GooglePlatformType, + apiVersion, + ...fields ?? {}, + }) + } + + beforeEach(async () => { + recorder = new GoogleRequestRecorder(); + callbacks = [recorder, new GoogleRequestLogger()]; + + const delay = testGeminiModelDelay[modelName] ?? 0; + if (delay) { + console.log(`Delaying for ${delay}ms`) + // eslint-disable-next-line no-promise-executor-return + await new Promise(resolve => setTimeout(resolve,delay)); + } + }); + + test("invoke", async () => { + const model = newChatGoogle({ + callbacks, + }); + const res = await model.invoke("What is 1 + 1?"); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); + }); + + test(`generate`, async () => { + const model = newChatGoogle(); + const messages: BaseMessage[] = [ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]; + const res = await model.predictMessages(messages); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(["H", "T"]).toContainEqual(text.trim()); + }); + + test("stream", async () => { + const model = newChatGoogle({ + callbacks, + }); + const input: BaseLanguageModelInput = new ChatPromptValue([ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]); + const res = await model.stream(input); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + expect(resArray).toBeDefined(); + expect(resArray.length).toBeGreaterThanOrEqual(1); + + const lastChunk = resArray[resArray.length - 1]; + expect(lastChunk).toBeDefined(); + expect(lastChunk._getType()).toEqual("ai"); + }); + + test("function", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, + }, + required: ["testName"], + }, + }, + ], + }, + ]; + const model = newChatGoogle().bind({ + tools, + }); + const result = await model.invoke("Run a test on the cobalt project"); + expect(result).toHaveProperty("content"); + expect(result.content).toBe(""); + const args = result?.lc_kwargs?.additional_kwargs; + expect(args).toBeDefined(); + expect(args).toHaveProperty("tool_calls"); + expect(Array.isArray(args.tool_calls)).toBeTruthy(); + expect(args.tool_calls).toHaveLength(1); + const call = args.tool_calls[0]; + expect(call).toHaveProperty("type"); + expect(call.type).toBe("function"); + expect(call).toHaveProperty("function"); + const func = call.function; + expect(func).toBeDefined(); + expect(func).toHaveProperty("name"); + expect(func.name).toBe("test"); + expect(func).toHaveProperty("arguments"); + expect(typeof func.arguments).toBe("string"); + expect(func.arguments.replaceAll("\n", "")).toBe('{"testName":"cobalt"}'); + }); + + test("function reply", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, + }, + required: ["testName"], + }, + }, + ], + }, + ]; + const model = newChatGoogle().bind({ + tools, + }); + const toolResult = { + testPassed: true, + }; + const messages: BaseMessageLike[] = [ + new HumanMessage("Run a test on the cobalt project."), + new AIMessage("", { + tool_calls: [ + { + id: "test", + type: "function", + function: { + name: "test", + arguments: '{"testName":"cobalt"}', + }, + }, + ], + }), + new ToolMessage(JSON.stringify(toolResult), "test"), + ]; + const res = await model.stream(messages); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + // console.log(JSON.stringify(resArray, null, 2)); + }); + + test("withStructuredOutput", async () => { + const tool = { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The name of city to get the weather for.", + }, + }, + required: ["location"], + }, + }; + const model = newChatGoogle().withStructuredOutput(tool); + const result = await model.invoke("What is the weather in Paris?"); + expect(result).toHaveProperty("location"); + }); + + test("media - fileData", async () => { + class MemStore extends InMemoryStore { + get length() { + return Object.keys(this.store).length; + } + } + const aliasMemory = new MemStore(); + const aliasStore = new BackedBlobStore({ + backingStore: aliasMemory, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + const backingStore = new BlobStoreGoogleCloudStorage({ + uriPrefix: new GoogleCloudStorageUri("gs://test-langchainjs/mediatest/"), + defaultStoreOptions: { + actionIfInvalid: "prefixPath", + }, + }); + const blobStore = new ReadThroughBlobStore({ + baseStore: aliasStore, + backingStore, + }); + const resolver = new SimpleWebBlobStore(); + const mediaManager = new MediaManager({ + store: blobStore, + resolvers: [resolver], + }); + const model = newChatGoogle({ + apiConfig: { + mediaManager, + }, + }); + + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + }, + ]; + + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; + + try { + const res = await model.invoke(messages); + + console.log(res); + + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/LangChain/); + } catch (e) { + console.error(e); + throw e; + } + }); + + test("Stream token count usage_metadata", async () => { + const model = newChatGoogle({ + temperature: 0, + maxOutputTokens: 10, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); + }); + + test("streamUsage excludes token usage", async () => { + const model = newChatGoogle({ + temperature: 0, + streamUsage: false, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).not.toBeDefined(); + }); + + test("Invoke token count usage_metadata", async () => { + const model = newChatGoogle({ + temperature: 0, + maxOutputTokens: 10, + }); + const res = await model.invoke("Why is the sky blue? Be concise."); + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); + }); + + test("Streaming true constructor param will stream", async () => { + const modelWithStreaming = newChatGoogle({ + maxOutputTokens: 50, + streaming: true, + }); + + let totalTokenCount = 0; + let tokensString = ""; + const result = await modelWithStreaming.invoke("What is 1 + 1?", { + callbacks: [ + { + handleLLMNewToken: (tok) => { + totalTokenCount += 1; + tokensString += tok; + }, + }, + ], + }); + + expect(result).toBeDefined(); + expect(result.content).toBe(tokensString); + + expect(totalTokenCount).toBeGreaterThan(1); + }); + + test("Can force a model to invoke a tool", async () => { + const model = newChatGoogle(); + const modelWithTools = model.bind({ + tools: [calculatorTool, weatherTool], + tool_choice: "calculator", + }); + + const result = await modelWithTools.invoke( + "Whats the weather like in paris today? What's 1836 plus 7262?" + ); + + expect(result.tool_calls).toHaveLength(1); + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) return; + expect(result.tool_calls?.[0].name).toBe("calculator"); + expect(result.tool_calls?.[0].args).toHaveProperty("expression"); + }); + + test(`stream tools`, async () => { + const model = newChatGoogle(); + + const weatherTool = tool( + (_) => "The weather in San Francisco today is 18 degrees and sunny.", + { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z.string().describe("The location to get the weather for."), + }), + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "Whats the weather like today in San Francisco?" + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; + + const toolCalls = finalChunk.tool_calls; + expect(toolCalls).toBeDefined(); + if (!toolCalls) { + throw new Error("tool_calls not in response"); + } + expect(toolCalls.length).toBe(1); + expect(toolCalls[0].name).toBe("current_weather_tool"); + expect(toolCalls[0].args).toHaveProperty("location"); + }); + + async function fileToBase64(filePath: string): Promise { + const fileData = await fs.readFile(filePath); + const base64String = Buffer.from(fileData).toString("base64"); + return base64String; + } + + test("Gemini can understand audio", async () => { + // Update this with the correct path to an audio file on your machine. + const audioPath = + "../langchain-google-genai/src/tests/data/gettysburg10.wav"; + const audioMimeType = "audio/wav"; + + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }); + + const audioBase64 = await fileToBase64(audioPath); + + const prompt = ChatPromptTemplate.fromMessages([ + new MessagesPlaceholder("audio"), + ]); + + const chain = prompt.pipe(model); + const response = await chain.invoke({ + audio: new HumanMessage({ + content: [ + { + type: "media", + mimeType: audioMimeType, + data: audioBase64, + }, + { + type: "text", + text: "Summarize the content in this audio. ALso, what is the speaker's tone?", + }, + ], + }), + }); + + expect(typeof response.content).toBe("string"); + expect((response.content as string).length).toBeGreaterThan(15); + }); + + test("Supports GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Supports GoogleSearchTool", async () => { + const searchTool: GeminiTool = { + googleSearch: { + }, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Can stream GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const stream = await model.stream("Who won the 2024 MLB World Series?"); + let finalMsg: AIMessageChunk | undefined; + for await (const msg of stream) { + finalMsg = finalMsg ? concat(finalMsg, msg) : msg; + } + if (!finalMsg) { + throw new Error("finalMsg is undefined"); + } + expect(finalMsg.content as string).toContain("Dodgers"); + }); + +}); From 519b8d8a3bbce7f19bc697fe6cee3f2db7a05421 Mon Sep 17 00:00:00 2001 From: afirstenberg Date: Thu, 26 Dec 2024 18:23:49 -0500 Subject: [PATCH 5/8] Make sure grounding results are available. --- .../src/tests/chat_models.test.ts | 144 ++++++++++++++++++ .../src/tests/data/chat-6-mock.json | 103 +++++++++++++ libs/langchain-google-common/src/types.ts | 67 ++++++++ .../src/utils/gemini.ts | 67 +++++++- .../src/tests/chat_models.int.test.ts | 33 ++-- 5 files changed, 401 insertions(+), 13 deletions(-) create mode 100644 libs/langchain-google-common/src/tests/data/chat-6-mock.json diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index aa15be74ed79..59998cf24768 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -1105,6 +1105,150 @@ describe("Mock ChatGoogle - Gemini", () => { // console.log(JSON.stringify(record?.opts?.data, null, 1)); }); + + test("6. GoogleSearchRetrievalTool result", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-pro-002", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); + expect(Array.isArray(result.response_metadata.groundingSupport)).toEqual(true); + expect(result.response_metadata.groundingSupport).toHaveLength(4); + }); + + test("6. GoogleSearchRetrievalTool request 1.5 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-pro-002", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearchRetrieval"); + }); + + test("6. GoogleSearchRetrievalTool request 2.0 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-2.0-flash", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearch"); + }); + + test("6. GoogleSearchTool request 1.5 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchTool = { + googleSearch: {}, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-pro-002", + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearchRetrieval"); + }); + + test("6. GoogleSearchTool request 2.0 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchTool = { + googleSearch: {}, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-2.0-flash", + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearch"); + }); + }); describe("Mock ChatGoogle - Anthropic", () => { diff --git a/libs/langchain-google-common/src/tests/data/chat-6-mock.json b/libs/langchain-google-common/src/tests/data/chat-6-mock.json new file mode 100644 index 000000000000..65568fb1810a --- /dev/null +++ b/libs/langchain-google-common/src/tests/data/chat-6-mock.json @@ -0,0 +1,103 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "The Los Angeles Dodgers won the 2024 World Series, defeating the New York Yankees 4-1 in the series. The Dodgers clinched the title with a 7-6 comeback victory in Game 5 at Yankee Stadium on Wednesday, October 30th. This was their eighth World Series title overall and their second in the past five years. It was also their first World Series win in a full season since 1988. Mookie Betts earned his third World Series ring (2018, 2020, and 2024), becoming the only active player with three championships. Shohei Ohtani, in his first year with the Dodgers, also experienced his first post-season appearance.\n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "groundingMetadata": { + "searchEntryPoint": { + "renderedContent": "\n
\n
\n \n \n \n \n \n \n \n \n \n \n \n \n \n
\n
\n \n
\n" + }, + "groundingChunks": [ + { + "web": { + "uri": "https://vertexaisearch.cloud.google.com/grounding-api-redirect/AYygrcTYmdnM71OvWYUTG4JggmRj8cIIgA2KtKas5RPj09CiALB4n8hl-SfCD6r8WnimL2psBoYmEN9ng9sENjpeP5VxgLMTlm0zgxhrWFfx3yA6B_n0N9j-BgHLISAUi-_Ql4_Buyw68Svq-3v6BgrXzn9hLOtK", + "title": "bbc.com" + } + }, + { + "web": { + "uri": "https://vertexaisearch.cloud.google.com/grounding-api-redirect/AYygrcQRhhvHTdpb8OMOEMVxv9fkevPoMWMnhrpuC7E0E0R94xmFxT9Vv5na1hMrfHGKxVZ9aE3PgCAs5nftC3iAkeD7B6ZTfKGH2Im1CqssMM7zorGx1Ds5_7QPPBDQps_JvpkOuvRluGCVg8KwNaIU-hm3Kg==", + "title": "mlb.com" + } + }, + { + "web": { + "uri": "https://vertexaisearch.cloud.google.com/grounding-api-redirect/AYygrcSwvb2t622A2ZpKxqOWKy16L1mEUvmsAJoHjaR7uffKO71SeZkpdRXRsST9HJzJkGSkMF9kOaXGoDtcvUrttqKYOQHvHSUBYO7LWMlU00KyNlSoQzrBsgN4KuJ4O4acnNyNCSVX3-E=", + "title": "youtube.com" + } + } + ], + "groundingSupports": [ + { + "segment": { + "endIndex": 100, + "text": "The Los Angeles Dodgers won the 2024 World Series, defeating the New York Yankees 4-1 in the series." + }, + "groundingChunkIndices": [ + 0 + ], + "confidenceScores": [ + 0.95898277 + ] + }, + { + "segment": { + "startIndex": 308, + "endIndex": 377, + "text": "It was also their first World Series win in a full season since 1988." + }, + "groundingChunkIndices": [ + 1 + ], + "confidenceScores": [ + 0.96841997 + ] + }, + { + "segment": { + "startIndex": 379, + "endIndex": 508, + "text": "Mookie Betts earned his third World Series ring (2018, 2020, and 2024), becoming the only active player with three championships." + }, + "groundingChunkIndices": [ + 2 + ], + "confidenceScores": [ + 0.99043523 + ] + }, + { + "segment": { + "startIndex": 510, + "endIndex": 611, + "text": "Shohei Ohtani, in his first year with the Dodgers, also experienced his first post-season appearance." + }, + "groundingChunkIndices": [ + 0 + ], + "confidenceScores": [ + 0.95767003 + ] + } + ], + "webSearchQueries": [ + "2024 MLB World Series winner" + ] + }, + "avgLogprobs": -0.040494912748883484 + } + ], + "usageMetadata": { + "promptTokenCount": 13, + "candidatesTokenCount": 157, + "totalTokenCount": 170 + }, + "modelVersion": "gemini-1.5-pro-002" +} diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index dfbcb0b334ef..2706d190cc78 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -299,6 +299,71 @@ export type GeminiSafetyRating = { probability: string; } & Record; +export interface GeminiCitationMetadata { + citations: GeminiCitation[]; +} + +export interface GeminiCitation { + startIndex: number; + endIndex: number; + uri: string; + title: string; + license: string; + publicationDate: GoogleTypeDate; +} + +export interface GoogleTypeDate { + year: number; // 1-9999 or 0 to specify a date without a year + month: number; // 1-12 or 0 to specify a year without a month and day + day: number; // Must be from 1 to 31 and valid for the year and month, or 0 to specify a year by itself or a year and month where the day isn't significant +} + +export interface GeminiGroundingMetadata { + webSearchQueries?: string[]; + searchEntryPoint?: GeminiSearchEntryPoint; + groundingChunks: GeminiGroundingChunk[]; + groundingSupports?: GeminiGroundingSupport[]; + retrievalMetadata?: GeminiRetrievalMetadata; +} + +export interface GeminiSearchEntryPoint { + renderedContent?: string; + sdkBlob?: string; // Base64 encoded JSON representing array of tuple. +} + +export interface GeminiGroundingChunk { + web: GeminiGroundingChunkWeb; + retrievedContext: GeminiGroundingChunkRetrievedContext; +} + +export interface GeminiGroundingChunkWeb { + uri: string; + title: string; +} + +export interface GeminiGroundingChunkRetrievedContext { + uri: string; + title: string; + text: string; +} + +export interface GeminiGroundingSupport { + segment: GeminiSegment; + groundingChunkIndices: number[]; + confidenceScores: number[]; +} + +export interface GeminiSegment { + partIndex: number; + startIndex: number; + endIndex: number; + text: string; +} + +export interface GeminiRetrievalMetadata { + googleSearchDynamicRetrievalScore: number; +} + // The "system" content appears to only be valid in the systemInstruction export type GeminiRole = "system" | "user" | "model" | "function"; @@ -412,6 +477,8 @@ interface GeminiResponseCandidate { index: number; tokenCount?: number; safetyRatings: GeminiSafetyRating[]; + citationMetadata?: GeminiCitationMetadata; + groundingMetadata?: GeminiGroundingMetadata; } interface GeminiResponsePromptFeedback { diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 8e61602ebc47..e6e122e85d3f 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -36,7 +36,7 @@ import type { GoogleAISafetyHandler, GeminiPartFunctionCall, GoogleAIAPI, - GeminiAPIConfig, + GeminiAPIConfig, GeminiGroundingSupport, } from "../types.js"; import { GoogleAISafetyError } from "./safety.js"; import { MediaBlob } from "../experimental/utils/media_core.js"; @@ -691,6 +691,8 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { severity: rating.severity, severity_score: rating.severityScore, })), + citation_metadata: data.candidates[0]?.citationMetadata, + grounding_metadata: data.candidates[0]?.groundingMetadata, finish_reason: data.candidates[0]?.finishReason, }; } @@ -750,7 +752,30 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { }); } - function responseToChatGenerations( + function groundingSupportByPart( + groundingSupports?: GeminiGroundingSupport[] + ): GeminiGroundingSupport[][] { + const ret: GeminiGroundingSupport[][] = []; + + if (!groundingSupports || groundingSupports.length === 0){ + return []; + } + + groundingSupports?.forEach((groundingSupport) => { + const segment = groundingSupport?.segment; + const partIndex = segment?.partIndex ?? 0; + if (ret[partIndex]) { + ret[partIndex].push(groundingSupport); + } else { + ret[partIndex] = [groundingSupport]; + } + + }); + + return ret; + } + + function responseToGroundedChatGenerations( response: GoogleLLMResponse ): ChatGeneration[] { const parts = responseToParts(response); @@ -759,7 +784,43 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { return []; } - let ret = parts.map((part) => partToChatGeneration(part)); + // Citation and grounding information connected to each part / ChatGeneration + // to make sure they are available in downstream filters. + const candidate = (response?.data as GenerateContentResponseData)?.candidates?.[0]; + const groundingMetadata = candidate?.groundingMetadata; + const citationMetadata = candidate?.citationMetadata; + const groundingParts = groundingSupportByPart(groundingMetadata?.groundingSupports); + + const ret = parts.map((part, index) => { + const gen = partToChatGeneration(part); + if (!gen.generationInfo) { + gen.generationInfo = {}; + } + if (groundingMetadata) { + gen.generationInfo.groundingMetadata = groundingMetadata; + const groundingPart = groundingParts[index]; + if (groundingPart) { + gen.generationInfo.groundingSupport = groundingPart; + } + } + if (citationMetadata) { + gen.generationInfo.citationMetadata = citationMetadata; + } + return gen; + }); + + return ret; + } + + function responseToChatGenerations( + response: GoogleLLMResponse + ): ChatGeneration[] { + let ret = responseToGroundedChatGenerations(response); + + if (ret.length === 0) { + return []; + } + if (ret.every((item) => typeof item.message.content === "string")) { const combinedContent = ret.map((item) => item.message.content).join(""); const combinedText = ret.map((item) => item.text).join(""); diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts index b406d61098a3..da3d3147c25f 100644 --- a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -22,7 +22,7 @@ import { ReadThroughBlobStore, SimpleWebBlobStore, } from "@langchain/google-common/experimental/utils/media_core"; -import {GeminiTool, GooglePlatformType, GoogleRequestLogger, GoogleRequestRecorder} from "@langchain/google-common"; +import {GeminiTool, GooglePlatformType, GoogleRequestRecorder} from "@langchain/google-common"; import {BaseCallbackHandler} from "@langchain/core/callbacks/base"; import {InMemoryStore} from "@langchain/core/stores"; import {BlobStoreGoogleCloudStorage} from "@langchain/google-gauth"; @@ -287,6 +287,8 @@ const testGeminiModelNames = [ {modelName: "gemini-1.5-flash-002", platformType: "gcp", apiVersion: "v1"}, {modelName: "gemini-2.0-flash-exp", platformType: "gai", apiVersion: "v1beta"}, {modelName: "gemini-2.0-flash-exp", platformType: "gcp", apiVersion: "v1"}, + + // Flash Thinking doesn't have functions or other features // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gai"}, // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gcp"}, ] @@ -305,18 +307,20 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model let callbacks: BaseCallbackHandler[]; function newChatGoogle(fields?: ChatGoogleInput): ChatGoogle { + // const logger = new GoogleRequestLogger(); + recorder = new GoogleRequestRecorder(); + callbacks = [recorder]; + return new ChatGoogle({ modelName, platformType: platformType as GooglePlatformType, apiVersion, + callbacks, ...fields ?? {}, }) } beforeEach(async () => { - recorder = new GoogleRequestRecorder(); - callbacks = [recorder, new GoogleRequestLogger()]; - const delay = testGeminiModelDelay[modelName] ?? 0; if (delay) { console.log(`Delaying for ${delay}ms`) @@ -326,9 +330,7 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model }); test("invoke", async () => { - const model = newChatGoogle({ - callbacks, - }); + const model = newChatGoogle(); const res = await model.invoke("What is 1 + 1?"); expect(res).toBeDefined(); expect(res._getType()).toEqual("ai"); @@ -339,6 +341,12 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model expect(typeof aiMessage.content).toBe("string"); const text = aiMessage.content as string; expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); + + expect(res).toHaveProperty("response_metadata"); + expect(res.response_metadata).not.toHaveProperty("groundingMetadata"); + expect(res.response_metadata).not.toHaveProperty("groundingSupport"); + + console.log(recorder); }); test(`generate`, async () => { @@ -364,9 +372,7 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model }); test("stream", async () => { - const model = newChatGoogle({ - callbacks, - }); + const model = newChatGoogle(); const input: BaseLanguageModelInput = new ChatPromptValue([ new SystemMessage( "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." @@ -650,6 +656,7 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model let tokensString = ""; const result = await modelWithStreaming.invoke("What is 1 + 1?", { callbacks: [ + ...callbacks, { handleLLMNewToken: (tok) => { totalTokenCount += 1; @@ -779,6 +786,9 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model const result = await model.invoke("Who won the 2024 MLB World Series?"); expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); }); test("Supports GoogleSearchTool", async () => { @@ -793,6 +803,9 @@ describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($model const result = await model.invoke("Who won the 2024 MLB World Series?"); expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); }); test("Can stream GoogleSearchRetrievalTool", async () => { From 4cb0d50a71f97a98c9821b737fc7f02d9ba38be9 Mon Sep 17 00:00:00 2001 From: afirstenberg Date: Thu, 26 Dec 2024 20:02:53 -0500 Subject: [PATCH 6/8] formatting --- .../src/chat_models.ts | 15 +- .../src/tests/chat_models.test.ts | 5 +- .../src/tests/data/chat-6-mock.json | 36 +- libs/langchain-google-common/src/types.ts | 14 +- .../src/utils/gemini.ts | 21 +- .../src/tests/chat_models.int.test.ts | 12 +- .../src/tests/chat_models.int.test.ts | 995 +++++++++--------- 7 files changed, 557 insertions(+), 541 deletions(-) diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index a9cfbcb951ad..15bf21fd3c94 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -98,7 +98,10 @@ export class ChatConnection extends AbstractGoogleLLMConnection< return true; } - computeGoogleSearchToolAdjustmentFromModel(): Exclude { + computeGoogleSearchToolAdjustmentFromModel(): Exclude< + GoogleSearchToolSetting, + boolean + > { if (this.modelName.startsWith("gemini-1.0")) { return "googleSearchRetrieval"; } else if (this.modelName.startsWith("gemini-1.5")) { @@ -108,7 +111,9 @@ export class ChatConnection extends AbstractGoogleLLMConnection< } } - computeGoogleSearchToolAdjustment(apiConfig: GeminiAPIConfig): Exclude { + computeGoogleSearchToolAdjustment( + apiConfig: GeminiAPIConfig + ): Exclude { const adj = apiConfig.googleSearchToolAdjustment; if (adj === undefined || adj === true) { return this.computeGoogleSearchToolAdjustmentFromModel(); @@ -118,8 +123,10 @@ export class ChatConnection extends AbstractGoogleLLMConnection< } buildGeminiAPI(): GoogleAIAPI { - const apiConfig: GeminiAPIConfig = this.apiConfig as GeminiAPIConfig ?? {}; - const googleSearchToolAdjustment = this.computeGoogleSearchToolAdjustment(apiConfig); + const apiConfig: GeminiAPIConfig = + (this.apiConfig as GeminiAPIConfig) ?? {}; + const googleSearchToolAdjustment = + this.computeGoogleSearchToolAdjustment(apiConfig); const geminiConfig: GeminiAPIConfig = { useSystemInstruction: this.useSystemInstruction, googleSearchToolAdjustment, diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index 59998cf24768..5726d9fd445e 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -1135,7 +1135,9 @@ describe("Mock ChatGoogle - Gemini", () => { expect(result).toHaveProperty("response_metadata"); expect(result.response_metadata).toHaveProperty("groundingMetadata"); expect(result.response_metadata).toHaveProperty("groundingSupport"); - expect(Array.isArray(result.response_metadata.groundingSupport)).toEqual(true); + expect(Array.isArray(result.response_metadata.groundingSupport)).toEqual( + true + ); expect(result.response_metadata.groundingSupport).toHaveLength(4); }); @@ -1248,7 +1250,6 @@ describe("Mock ChatGoogle - Gemini", () => { expect(record.opts.data.tools[0]).toHaveProperty("googleSearch"); }); - }); describe("Mock ChatGoogle - Anthropic", () => { diff --git a/libs/langchain-google-common/src/tests/data/chat-6-mock.json b/libs/langchain-google-common/src/tests/data/chat-6-mock.json index 65568fb1810a..796fdcf9bcee 100644 --- a/libs/langchain-google-common/src/tests/data/chat-6-mock.json +++ b/libs/langchain-google-common/src/tests/data/chat-6-mock.json @@ -40,12 +40,8 @@ "endIndex": 100, "text": "The Los Angeles Dodgers won the 2024 World Series, defeating the New York Yankees 4-1 in the series." }, - "groundingChunkIndices": [ - 0 - ], - "confidenceScores": [ - 0.95898277 - ] + "groundingChunkIndices": [0], + "confidenceScores": [0.95898277] }, { "segment": { @@ -53,12 +49,8 @@ "endIndex": 377, "text": "It was also their first World Series win in a full season since 1988." }, - "groundingChunkIndices": [ - 1 - ], - "confidenceScores": [ - 0.96841997 - ] + "groundingChunkIndices": [1], + "confidenceScores": [0.96841997] }, { "segment": { @@ -66,12 +58,8 @@ "endIndex": 508, "text": "Mookie Betts earned his third World Series ring (2018, 2020, and 2024), becoming the only active player with three championships." }, - "groundingChunkIndices": [ - 2 - ], - "confidenceScores": [ - 0.99043523 - ] + "groundingChunkIndices": [2], + "confidenceScores": [0.99043523] }, { "segment": { @@ -79,17 +67,11 @@ "endIndex": 611, "text": "Shohei Ohtani, in his first year with the Dodgers, also experienced his first post-season appearance." }, - "groundingChunkIndices": [ - 0 - ], - "confidenceScores": [ - 0.95767003 - ] + "groundingChunkIndices": [0], + "confidenceScores": [0.95767003] } ], - "webSearchQueries": [ - "2024 MLB World Series winner" - ] + "webSearchQueries": ["2024 MLB World Series winner"] }, "avgLogprobs": -0.040494912748883484 } diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index 2706d190cc78..3b702cda6f87 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -313,9 +313,9 @@ export interface GeminiCitation { } export interface GoogleTypeDate { - year: number; // 1-9999 or 0 to specify a date without a year + year: number; // 1-9999 or 0 to specify a date without a year month: number; // 1-12 or 0 to specify a year without a month and day - day: number; // Must be from 1 to 31 and valid for the year and month, or 0 to specify a year by itself or a year and month where the day isn't significant + day: number; // Must be from 1 to 31 and valid for the year and month, or 0 to specify a year by itself or a year and month where the day isn't significant } export interface GeminiGroundingMetadata { @@ -328,7 +328,7 @@ export interface GeminiGroundingMetadata { export interface GeminiSearchEntryPoint { renderedContent?: string; - sdkBlob?: string; // Base64 encoded JSON representing array of tuple. + sdkBlob?: string; // Base64 encoded JSON representing array of tuple. } export interface GeminiGroundingChunk { @@ -378,8 +378,8 @@ export interface GeminiContent { */ export interface GeminiTool { functionDeclarations?: GeminiFunctionDeclaration[]; - googleSearchRetrieval?: GoogleSearchRetrieval; // Gemini-1.5 - googleSearch?: GoogleSearch; // Gemini-2.0 + googleSearchRetrieval?: GoogleSearchRetrieval; // Gemini-1.5 + googleSearch?: GoogleSearch; // Gemini-2.0 retrieval?: VertexAIRetrieval; } @@ -395,13 +395,13 @@ export type GoogleSearchToolSetting = export const GeminiSearchToolAttributes = [ "googleSearchRetrieval", "googleSearch", -] +]; export const GeminiToolAttributes = [ "functionDeclaration", "retrieval", ...GeminiSearchToolAttributes, -] +]; export interface GoogleSearchRetrieval { dynamicRetrievalConfig?: { diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index e6e122e85d3f..7f532983d6a4 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -36,7 +36,8 @@ import type { GoogleAISafetyHandler, GeminiPartFunctionCall, GoogleAIAPI, - GeminiAPIConfig, GeminiGroundingSupport, + GeminiAPIConfig, + GeminiGroundingSupport, } from "../types.js"; import { GoogleAISafetyError } from "./safety.js"; import { MediaBlob } from "../experimental/utils/media_core.js"; @@ -757,7 +758,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { ): GeminiGroundingSupport[][] { const ret: GeminiGroundingSupport[][] = []; - if (!groundingSupports || groundingSupports.length === 0){ + if (!groundingSupports || groundingSupports.length === 0) { return []; } @@ -769,7 +770,6 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { } else { ret[partIndex] = [groundingSupport]; } - }); return ret; @@ -786,10 +786,13 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { // Citation and grounding information connected to each part / ChatGeneration // to make sure they are available in downstream filters. - const candidate = (response?.data as GenerateContentResponseData)?.candidates?.[0]; + const candidate = (response?.data as GenerateContentResponseData) + ?.candidates?.[0]; const groundingMetadata = candidate?.groundingMetadata; const citationMetadata = candidate?.citationMetadata; - const groundingParts = groundingSupportByPart(groundingMetadata?.groundingSupports); + const groundingParts = groundingSupportByPart( + groundingMetadata?.groundingSupports + ); const ret = parts.map((part, index) => { const gen = partToChatGeneration(part); @@ -1080,7 +1083,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { function searchToolName(tool: GeminiTool): string | undefined { for (const name of GeminiSearchToolAttributes) { if (name in tool) { - return name + return name; } } return undefined; @@ -1092,7 +1095,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { if (orig && adj && adj !== orig) { return { [adj as string]: {}, - } + }; } else { return tool; } @@ -1108,13 +1111,13 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { // Gemini Tools may be normalized to different tool names const langChainTools: StructuredToolParams[] = []; const otherTools: GeminiTool[] = []; - tools.forEach(tool => { + tools.forEach((tool) => { if (isLangChainTool(tool)) { langChainTools.push(tool); } else { otherTools.push(cleanGeminiTool(tool as GeminiTool)); } - }) + }); const result: GeminiTool[] = [...otherTools]; diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index 49faa348a82f..6d1606614bd1 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -67,7 +67,7 @@ const testGeminiModelNames = [ ["gemini-1.5-flash-002"], ["gemini-2.0-flash-exp"], // ["gemini-2.0-flash-thinking-exp-1219"], -] +]; /* * Some models may have usage quotas still. @@ -76,7 +76,7 @@ const testGeminiModelNames = [ const testGeminiModelDelay: Record = { "gemini-2.0-flash-exp": 5000, "gemini-2.0-flash-thinking-exp-1219": 5000, -} +}; describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { let recorder: GoogleRequestRecorder; @@ -88,9 +88,9 @@ describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { const delay = testGeminiModelDelay[modelName] ?? 0; if (delay) { - console.log(`Delaying for ${delay}ms`) + console.log(`Delaying for ${delay}ms`); // eslint-disable-next-line no-promise-executor-return - await new Promise(resolve => setTimeout(resolve,delay)); + await new Promise((resolve) => setTimeout(resolve, delay)); } }); @@ -573,8 +573,7 @@ describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { test("Supports GoogleSearchTool", async () => { const searchTool: GeminiTool = { - googleSearch: { - }, + googleSearch: {}, }; const model = new ChatVertexAI({ modelName, @@ -611,7 +610,6 @@ describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { } expect(finalMsg.content as string).toContain("Dodgers"); }); - }); describe("GAuth Anthropic Chat", () => { diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts index da3d3147c25f..5379533c7181 100644 --- a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -6,7 +6,8 @@ import { AIMessage, AIMessageChunk, BaseMessage, - BaseMessageChunk, BaseMessageLike, + BaseMessageChunk, + BaseMessageLike, HumanMessage, HumanMessageChunk, MessageContentComplex, @@ -22,15 +23,22 @@ import { ReadThroughBlobStore, SimpleWebBlobStore, } from "@langchain/google-common/experimental/utils/media_core"; -import {GeminiTool, GooglePlatformType, GoogleRequestRecorder} from "@langchain/google-common"; -import {BaseCallbackHandler} from "@langchain/core/callbacks/base"; -import {InMemoryStore} from "@langchain/core/stores"; -import {BlobStoreGoogleCloudStorage} from "@langchain/google-gauth"; -import {GoogleCloudStorageUri} from "@langchain/google-common/experimental/media"; -import {concat} from "@langchain/core/utils/stream"; +import { + GeminiTool, + GooglePlatformType, + GoogleRequestRecorder, +} from "@langchain/google-common"; +import { BaseCallbackHandler } from "@langchain/core/callbacks/base"; +import { InMemoryStore } from "@langchain/core/stores"; +import { BlobStoreGoogleCloudStorage } from "@langchain/google-gauth"; +import { GoogleCloudStorageUri } from "@langchain/google-common/experimental/media"; +import { concat } from "@langchain/core/utils/stream"; import fs from "fs/promises"; -import {ChatPromptTemplate, MessagesPlaceholder} from "@langchain/core/prompts"; -import {ChatGoogle, ChatGoogleInput} from "../chat_models.js"; +import { + ChatPromptTemplate, + MessagesPlaceholder, +} from "@langchain/core/prompts"; +import { ChatGoogle, ChatGoogleInput } from "../chat_models.js"; import { BlobStoreAIStudioFile } from "../media.js"; class WeatherTool extends StructuredTool { @@ -281,17 +289,29 @@ const calculatorTool = tool((_) => "no-op", { * and on which platforms? */ const testGeminiModelNames = [ - {modelName: "gemini-1.5-pro-002", platformType: "gai", apiVersion: "v1beta"}, - {modelName: "gemini-1.5-pro-002", platformType: "gcp", apiVersion: "v1"}, - {modelName: "gemini-1.5-flash-002", platformType: "gai", apiVersion: "v1beta"}, - {modelName: "gemini-1.5-flash-002", platformType: "gcp", apiVersion: "v1"}, - {modelName: "gemini-2.0-flash-exp", platformType: "gai", apiVersion: "v1beta"}, - {modelName: "gemini-2.0-flash-exp", platformType: "gcp", apiVersion: "v1"}, + { + modelName: "gemini-1.5-pro-002", + platformType: "gai", + apiVersion: "v1beta", + }, + { modelName: "gemini-1.5-pro-002", platformType: "gcp", apiVersion: "v1" }, + { + modelName: "gemini-1.5-flash-002", + platformType: "gai", + apiVersion: "v1beta", + }, + { modelName: "gemini-1.5-flash-002", platformType: "gcp", apiVersion: "v1" }, + { + modelName: "gemini-2.0-flash-exp", + platformType: "gai", + apiVersion: "v1beta", + }, + { modelName: "gemini-2.0-flash-exp", platformType: "gcp", apiVersion: "v1" }, // Flash Thinking doesn't have functions or other features // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gai"}, // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gcp"}, -] +]; /* * Some models may have usage quotas still. @@ -300,537 +320,542 @@ const testGeminiModelNames = [ const testGeminiModelDelay: Record = { "gemini-2.0-flash-exp": 10000, "gemini-2.0-flash-thinking-exp-1219": 10000, -} - -describe.each(testGeminiModelNames)("Webauth ($platformType) Gemini Chat ($modelName)", ({modelName, platformType, apiVersion}) => { - let recorder: GoogleRequestRecorder; - let callbacks: BaseCallbackHandler[]; - - function newChatGoogle(fields?: ChatGoogleInput): ChatGoogle { - // const logger = new GoogleRequestLogger(); - recorder = new GoogleRequestRecorder(); - callbacks = [recorder]; - - return new ChatGoogle({ - modelName, - platformType: platformType as GooglePlatformType, - apiVersion, - callbacks, - ...fields ?? {}, - }) - } - - beforeEach(async () => { - const delay = testGeminiModelDelay[modelName] ?? 0; - if (delay) { - console.log(`Delaying for ${delay}ms`) - // eslint-disable-next-line no-promise-executor-return - await new Promise(resolve => setTimeout(resolve,delay)); +}; + +describe.each(testGeminiModelNames)( + "Webauth ($platformType) Gemini Chat ($modelName)", + ({ modelName, platformType, apiVersion }) => { + let recorder: GoogleRequestRecorder; + let callbacks: BaseCallbackHandler[]; + + function newChatGoogle(fields?: ChatGoogleInput): ChatGoogle { + // const logger = new GoogleRequestLogger(); + recorder = new GoogleRequestRecorder(); + callbacks = [recorder]; + + return new ChatGoogle({ + modelName, + platformType: platformType as GooglePlatformType, + apiVersion, + callbacks, + ...(fields ?? {}), + }); } - }); - test("invoke", async () => { - const model = newChatGoogle(); - const res = await model.invoke("What is 1 + 1?"); - expect(res).toBeDefined(); - expect(res._getType()).toEqual("ai"); + beforeEach(async () => { + const delay = testGeminiModelDelay[modelName] ?? 0; + if (delay) { + console.log(`Delaying for ${delay}ms`); + // eslint-disable-next-line no-promise-executor-return + await new Promise((resolve) => setTimeout(resolve, delay)); + } + }); - const aiMessage = res as AIMessageChunk; - expect(aiMessage.content).toBeDefined(); + test("invoke", async () => { + const model = newChatGoogle(); + const res = await model.invoke("What is 1 + 1?"); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); - expect(typeof aiMessage.content).toBe("string"); - const text = aiMessage.content as string; - expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); - expect(res).toHaveProperty("response_metadata"); - expect(res.response_metadata).not.toHaveProperty("groundingMetadata"); - expect(res.response_metadata).not.toHaveProperty("groundingSupport"); + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); - console.log(recorder); - }); + expect(res).toHaveProperty("response_metadata"); + expect(res.response_metadata).not.toHaveProperty("groundingMetadata"); + expect(res.response_metadata).not.toHaveProperty("groundingSupport"); - test(`generate`, async () => { - const model = newChatGoogle(); - const messages: BaseMessage[] = [ - new SystemMessage( - "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." - ), - new HumanMessage("Flip it"), - new AIMessage("T"), - new HumanMessage("Flip the coin again"), - ]; - const res = await model.predictMessages(messages); - expect(res).toBeDefined(); - expect(res._getType()).toEqual("ai"); + console.log(recorder); + }); - const aiMessage = res as AIMessageChunk; - expect(aiMessage.content).toBeDefined(); + test(`generate`, async () => { + const model = newChatGoogle(); + const messages: BaseMessage[] = [ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]; + const res = await model.predictMessages(messages); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); - expect(typeof aiMessage.content).toBe("string"); - const text = aiMessage.content as string; - expect(["H", "T"]).toContainEqual(text.trim()); - }); + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); - test("stream", async () => { - const model = newChatGoogle(); - const input: BaseLanguageModelInput = new ChatPromptValue([ - new SystemMessage( - "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." - ), - new HumanMessage("Flip it"), - new AIMessage("T"), - new HumanMessage("Flip the coin again"), - ]); - const res = await model.stream(input); - const resArray: BaseMessageChunk[] = []; - for await (const chunk of res) { - resArray.push(chunk); - } - expect(resArray).toBeDefined(); - expect(resArray.length).toBeGreaterThanOrEqual(1); + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(["H", "T"]).toContainEqual(text.trim()); + }); - const lastChunk = resArray[resArray.length - 1]; - expect(lastChunk).toBeDefined(); - expect(lastChunk._getType()).toEqual("ai"); - }); + test("stream", async () => { + const model = newChatGoogle(); + const input: BaseLanguageModelInput = new ChatPromptValue([ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]); + const res = await model.stream(input); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + expect(resArray).toBeDefined(); + expect(resArray.length).toBeGreaterThanOrEqual(1); - test("function", async () => { - const tools: GeminiTool[] = [ - { - functionDeclarations: [ - { - name: "test", - description: - "Run a test with a specific name and get if it passed or failed", - parameters: { - type: "object", - properties: { - testName: { - type: "string", - description: "The name of the test that should be run.", + const lastChunk = resArray[resArray.length - 1]; + expect(lastChunk).toBeDefined(); + expect(lastChunk._getType()).toEqual("ai"); + }); + + test("function", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, }, + required: ["testName"], }, - required: ["testName"], }, - }, - ], - }, - ]; - const model = newChatGoogle().bind({ - tools, + ], + }, + ]; + const model = newChatGoogle().bind({ + tools, + }); + const result = await model.invoke("Run a test on the cobalt project"); + expect(result).toHaveProperty("content"); + expect(result.content).toBe(""); + const args = result?.lc_kwargs?.additional_kwargs; + expect(args).toBeDefined(); + expect(args).toHaveProperty("tool_calls"); + expect(Array.isArray(args.tool_calls)).toBeTruthy(); + expect(args.tool_calls).toHaveLength(1); + const call = args.tool_calls[0]; + expect(call).toHaveProperty("type"); + expect(call.type).toBe("function"); + expect(call).toHaveProperty("function"); + const func = call.function; + expect(func).toBeDefined(); + expect(func).toHaveProperty("name"); + expect(func.name).toBe("test"); + expect(func).toHaveProperty("arguments"); + expect(typeof func.arguments).toBe("string"); + expect(func.arguments.replaceAll("\n", "")).toBe('{"testName":"cobalt"}'); }); - const result = await model.invoke("Run a test on the cobalt project"); - expect(result).toHaveProperty("content"); - expect(result.content).toBe(""); - const args = result?.lc_kwargs?.additional_kwargs; - expect(args).toBeDefined(); - expect(args).toHaveProperty("tool_calls"); - expect(Array.isArray(args.tool_calls)).toBeTruthy(); - expect(args.tool_calls).toHaveLength(1); - const call = args.tool_calls[0]; - expect(call).toHaveProperty("type"); - expect(call.type).toBe("function"); - expect(call).toHaveProperty("function"); - const func = call.function; - expect(func).toBeDefined(); - expect(func).toHaveProperty("name"); - expect(func.name).toBe("test"); - expect(func).toHaveProperty("arguments"); - expect(typeof func.arguments).toBe("string"); - expect(func.arguments.replaceAll("\n", "")).toBe('{"testName":"cobalt"}'); - }); - test("function reply", async () => { - const tools: GeminiTool[] = [ - { - functionDeclarations: [ - { - name: "test", - description: - "Run a test with a specific name and get if it passed or failed", - parameters: { - type: "object", - properties: { - testName: { - type: "string", - description: "The name of the test that should be run.", + test("function reply", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, }, + required: ["testName"], }, - required: ["testName"], }, - }, - ], - }, - ]; - const model = newChatGoogle().bind({ - tools, - }); - const toolResult = { - testPassed: true, - }; - const messages: BaseMessageLike[] = [ - new HumanMessage("Run a test on the cobalt project."), - new AIMessage("", { - tool_calls: [ - { - id: "test", - type: "function", - function: { - name: "test", - arguments: '{"testName":"cobalt"}', + ], + }, + ]; + const model = newChatGoogle().bind({ + tools, + }); + const toolResult = { + testPassed: true, + }; + const messages: BaseMessageLike[] = [ + new HumanMessage("Run a test on the cobalt project."), + new AIMessage("", { + tool_calls: [ + { + id: "test", + type: "function", + function: { + name: "test", + arguments: '{"testName":"cobalt"}', + }, }, - }, - ], - }), - new ToolMessage(JSON.stringify(toolResult), "test"), - ]; - const res = await model.stream(messages); - const resArray: BaseMessageChunk[] = []; - for await (const chunk of res) { - resArray.push(chunk); - } - // console.log(JSON.stringify(resArray, null, 2)); - }); + ], + }), + new ToolMessage(JSON.stringify(toolResult), "test"), + ]; + const res = await model.stream(messages); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + // console.log(JSON.stringify(resArray, null, 2)); + }); - test("withStructuredOutput", async () => { - const tool = { - name: "get_weather", - description: - "Get the weather of a specific location and return the temperature in Celsius.", - parameters: { - type: "object", - properties: { - location: { - type: "string", - description: "The name of city to get the weather for.", + test("withStructuredOutput", async () => { + const tool = { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The name of city to get the weather for.", + }, }, + required: ["location"], }, - required: ["location"], - }, - }; - const model = newChatGoogle().withStructuredOutput(tool); - const result = await model.invoke("What is the weather in Paris?"); - expect(result).toHaveProperty("location"); - }); - - test("media - fileData", async () => { - class MemStore extends InMemoryStore { - get length() { - return Object.keys(this.store).length; - } - } - const aliasMemory = new MemStore(); - const aliasStore = new BackedBlobStore({ - backingStore: aliasMemory, - defaultFetchOptions: { - actionIfBlobMissing: undefined, - }, - }); - const backingStore = new BlobStoreGoogleCloudStorage({ - uriPrefix: new GoogleCloudStorageUri("gs://test-langchainjs/mediatest/"), - defaultStoreOptions: { - actionIfInvalid: "prefixPath", - }, - }); - const blobStore = new ReadThroughBlobStore({ - baseStore: aliasStore, - backingStore, - }); - const resolver = new SimpleWebBlobStore(); - const mediaManager = new MediaManager({ - store: blobStore, - resolvers: [resolver], - }); - const model = newChatGoogle({ - apiConfig: { - mediaManager, - }, + }; + const model = newChatGoogle().withStructuredOutput(tool); + const result = await model.invoke("What is the weather in Paris?"); + expect(result).toHaveProperty("location"); }); - const message: MessageContentComplex[] = [ - { - type: "text", - text: "What is in this image?", - }, - { - type: "media", - fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", - }, - ]; + test("media - fileData", async () => { + class MemStore extends InMemoryStore { + get length() { + return Object.keys(this.store).length; + } + } + const aliasMemory = new MemStore(); + const aliasStore = new BackedBlobStore({ + backingStore: aliasMemory, + defaultFetchOptions: { + actionIfBlobMissing: undefined, + }, + }); + const backingStore = new BlobStoreGoogleCloudStorage({ + uriPrefix: new GoogleCloudStorageUri( + "gs://test-langchainjs/mediatest/" + ), + defaultStoreOptions: { + actionIfInvalid: "prefixPath", + }, + }); + const blobStore = new ReadThroughBlobStore({ + baseStore: aliasStore, + backingStore, + }); + const resolver = new SimpleWebBlobStore(); + const mediaManager = new MediaManager({ + store: blobStore, + resolvers: [resolver], + }); + const model = newChatGoogle({ + apiConfig: { + mediaManager, + }, + }); - const messages: BaseMessage[] = [ - new HumanMessageChunk({ content: message }), - ]; + const message: MessageContentComplex[] = [ + { + type: "text", + text: "What is in this image?", + }, + { + type: "media", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + }, + ]; - try { - const res = await model.invoke(messages); + const messages: BaseMessage[] = [ + new HumanMessageChunk({ content: message }), + ]; - console.log(res); + try { + const res = await model.invoke(messages); - expect(res).toBeDefined(); - expect(res._getType()).toEqual("ai"); + console.log(res); - const aiMessage = res as AIMessageChunk; - expect(aiMessage.content).toBeDefined(); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); - expect(typeof aiMessage.content).toBe("string"); - const text = aiMessage.content as string; - expect(text).toMatch(/LangChain/); - } catch (e) { - console.error(e); - throw e; - } - }); + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); - test("Stream token count usage_metadata", async () => { - const model = newChatGoogle({ - temperature: 0, - maxOutputTokens: 10, - }); - let res: AIMessageChunk | null = null; - for await (const chunk of await model.stream( - "Why is the sky blue? Be concise." - )) { - if (!res) { - res = chunk; - } else { - res = res.concat(chunk); + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/LangChain/); + } catch (e) { + console.error(e); + throw e; } - } - // console.log(res); - expect(res?.usage_metadata).toBeDefined(); - if (!res?.usage_metadata) { - return; - } - expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.total_tokens).toBe( - res.usage_metadata.input_tokens + res.usage_metadata.output_tokens - ); - }); - - test("streamUsage excludes token usage", async () => { - const model = newChatGoogle({ - temperature: 0, - streamUsage: false, }); - let res: AIMessageChunk | null = null; - for await (const chunk of await model.stream( - "Why is the sky blue? Be concise." - )) { - if (!res) { - res = chunk; - } else { - res = res.concat(chunk); - } - } - // console.log(res); - expect(res?.usage_metadata).not.toBeDefined(); - }); - test("Invoke token count usage_metadata", async () => { - const model = newChatGoogle({ - temperature: 0, - maxOutputTokens: 10, + test("Stream token count usage_metadata", async () => { + const model = newChatGoogle({ + temperature: 0, + maxOutputTokens: 10, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); }); - const res = await model.invoke("Why is the sky blue? Be concise."); - // console.log(res); - expect(res?.usage_metadata).toBeDefined(); - if (!res?.usage_metadata) { - return; - } - expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.total_tokens).toBe( - res.usage_metadata.input_tokens + res.usage_metadata.output_tokens - ); - }); - test("Streaming true constructor param will stream", async () => { - const modelWithStreaming = newChatGoogle({ - maxOutputTokens: 50, - streaming: true, + test("streamUsage excludes token usage", async () => { + const model = newChatGoogle({ + temperature: 0, + streamUsage: false, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).not.toBeDefined(); }); - let totalTokenCount = 0; - let tokensString = ""; - const result = await modelWithStreaming.invoke("What is 1 + 1?", { - callbacks: [ - ...callbacks, - { - handleLLMNewToken: (tok) => { - totalTokenCount += 1; - tokensString += tok; - }, - }, - ], + test("Invoke token count usage_metadata", async () => { + const model = newChatGoogle({ + temperature: 0, + maxOutputTokens: 10, + }); + const res = await model.invoke("Why is the sky blue? Be concise."); + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); }); - expect(result).toBeDefined(); - expect(result.content).toBe(tokensString); + test("Streaming true constructor param will stream", async () => { + const modelWithStreaming = newChatGoogle({ + maxOutputTokens: 50, + streaming: true, + }); + + let totalTokenCount = 0; + let tokensString = ""; + const result = await modelWithStreaming.invoke("What is 1 + 1?", { + callbacks: [ + ...callbacks, + { + handleLLMNewToken: (tok) => { + totalTokenCount += 1; + tokensString += tok; + }, + }, + ], + }); - expect(totalTokenCount).toBeGreaterThan(1); - }); + expect(result).toBeDefined(); + expect(result.content).toBe(tokensString); - test("Can force a model to invoke a tool", async () => { - const model = newChatGoogle(); - const modelWithTools = model.bind({ - tools: [calculatorTool, weatherTool], - tool_choice: "calculator", + expect(totalTokenCount).toBeGreaterThan(1); }); - const result = await modelWithTools.invoke( - "Whats the weather like in paris today? What's 1836 plus 7262?" - ); - - expect(result.tool_calls).toHaveLength(1); - expect(result.tool_calls?.[0]).toBeDefined(); - if (!result.tool_calls?.[0]) return; - expect(result.tool_calls?.[0].name).toBe("calculator"); - expect(result.tool_calls?.[0].args).toHaveProperty("expression"); - }); + test("Can force a model to invoke a tool", async () => { + const model = newChatGoogle(); + const modelWithTools = model.bind({ + tools: [calculatorTool, weatherTool], + tool_choice: "calculator", + }); + + const result = await modelWithTools.invoke( + "Whats the weather like in paris today? What's 1836 plus 7262?" + ); + + expect(result.tool_calls).toHaveLength(1); + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) return; + expect(result.tool_calls?.[0].name).toBe("calculator"); + expect(result.tool_calls?.[0].args).toHaveProperty("expression"); + }); - test(`stream tools`, async () => { - const model = newChatGoogle(); + test(`stream tools`, async () => { + const model = newChatGoogle(); - const weatherTool = tool( - (_) => "The weather in San Francisco today is 18 degrees and sunny.", - { - name: "current_weather_tool", - description: "Get the current weather for a given location.", - schema: z.object({ - location: z.string().describe("The location to get the weather for."), - }), + const weatherTool = tool( + (_) => "The weather in San Francisco today is 18 degrees and sunny.", + { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z + .string() + .describe("The location to get the weather for."), + }), + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "Whats the weather like today in San Francisco?" + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); } - ); - const modelWithTools = model.bindTools([weatherTool]); - const stream = await modelWithTools.stream( - "Whats the weather like today in San Francisco?" - ); - let finalChunk: AIMessageChunk | undefined; - for await (const chunk of stream) { - finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); - } + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; - expect(finalChunk).toBeDefined(); - if (!finalChunk) return; + const toolCalls = finalChunk.tool_calls; + expect(toolCalls).toBeDefined(); + if (!toolCalls) { + throw new Error("tool_calls not in response"); + } + expect(toolCalls.length).toBe(1); + expect(toolCalls[0].name).toBe("current_weather_tool"); + expect(toolCalls[0].args).toHaveProperty("location"); + }); - const toolCalls = finalChunk.tool_calls; - expect(toolCalls).toBeDefined(); - if (!toolCalls) { - throw new Error("tool_calls not in response"); + async function fileToBase64(filePath: string): Promise { + const fileData = await fs.readFile(filePath); + const base64String = Buffer.from(fileData).toString("base64"); + return base64String; } - expect(toolCalls.length).toBe(1); - expect(toolCalls[0].name).toBe("current_weather_tool"); - expect(toolCalls[0].args).toHaveProperty("location"); - }); - async function fileToBase64(filePath: string): Promise { - const fileData = await fs.readFile(filePath); - const base64String = Buffer.from(fileData).toString("base64"); - return base64String; - } + test("Gemini can understand audio", async () => { + // Update this with the correct path to an audio file on your machine. + const audioPath = + "../langchain-google-genai/src/tests/data/gettysburg10.wav"; + const audioMimeType = "audio/wav"; - test("Gemini can understand audio", async () => { - // Update this with the correct path to an audio file on your machine. - const audioPath = - "../langchain-google-genai/src/tests/data/gettysburg10.wav"; - const audioMimeType = "audio/wav"; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }); - const model = newChatGoogle({ - temperature: 0, - maxRetries: 0, - }); + const audioBase64 = await fileToBase64(audioPath); - const audioBase64 = await fileToBase64(audioPath); + const prompt = ChatPromptTemplate.fromMessages([ + new MessagesPlaceholder("audio"), + ]); - const prompt = ChatPromptTemplate.fromMessages([ - new MessagesPlaceholder("audio"), - ]); + const chain = prompt.pipe(model); + const response = await chain.invoke({ + audio: new HumanMessage({ + content: [ + { + type: "media", + mimeType: audioMimeType, + data: audioBase64, + }, + { + type: "text", + text: "Summarize the content in this audio. ALso, what is the speaker's tone?", + }, + ], + }), + }); - const chain = prompt.pipe(model); - const response = await chain.invoke({ - audio: new HumanMessage({ - content: [ - { - type: "media", - mimeType: audioMimeType, - data: audioBase64, - }, - { - type: "text", - text: "Summarize the content in this audio. ALso, what is the speaker's tone?", - }, - ], - }), + expect(typeof response.content).toBe("string"); + expect((response.content as string).length).toBeGreaterThan(15); }); - expect(typeof response.content).toBe("string"); - expect((response.content as string).length).toBeGreaterThan(15); - }); - - test("Supports GoogleSearchRetrievalTool", async () => { - const searchRetrievalTool = { - googleSearchRetrieval: { - dynamicRetrievalConfig: { - mode: "MODE_DYNAMIC", - dynamicThreshold: 0.7, // default is 0.7 + test("Supports GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, }, - }, - }; - const model = newChatGoogle({ - temperature: 0, - maxRetries: 0, - }).bindTools([searchRetrievalTool]); - - const result = await model.invoke("Who won the 2024 MLB World Series?"); - expect(result.content as string).toContain("Dodgers"); - expect(result).toHaveProperty("response_metadata"); - expect(result.response_metadata).toHaveProperty("groundingMetadata"); - expect(result.response_metadata).toHaveProperty("groundingSupport"); - }); + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); + }); - test("Supports GoogleSearchTool", async () => { - const searchTool: GeminiTool = { - googleSearch: { - }, - }; - const model = newChatGoogle({ - temperature: 0, - maxRetries: 0, - }).bindTools([searchTool]); - - const result = await model.invoke("Who won the 2024 MLB World Series?"); - expect(result.content as string).toContain("Dodgers"); - expect(result).toHaveProperty("response_metadata"); - expect(result.response_metadata).toHaveProperty("groundingMetadata"); - expect(result.response_metadata).toHaveProperty("groundingSupport"); - }); + test("Supports GoogleSearchTool", async () => { + const searchTool: GeminiTool = { + googleSearch: {}, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); + }); - test("Can stream GoogleSearchRetrievalTool", async () => { - const searchRetrievalTool = { - googleSearchRetrieval: { - dynamicRetrievalConfig: { - mode: "MODE_DYNAMIC", - dynamicThreshold: 0.7, // default is 0.7 + test("Can stream GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, }, - }, - }; - const model = newChatGoogle({ - temperature: 0, - maxRetries: 0, - }).bindTools([searchRetrievalTool]); - - const stream = await model.stream("Who won the 2024 MLB World Series?"); - let finalMsg: AIMessageChunk | undefined; - for await (const msg of stream) { - finalMsg = finalMsg ? concat(finalMsg, msg) : msg; - } - if (!finalMsg) { - throw new Error("finalMsg is undefined"); - } - expect(finalMsg.content as string).toContain("Dodgers"); - }); - -}); + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const stream = await model.stream("Who won the 2024 MLB World Series?"); + let finalMsg: AIMessageChunk | undefined; + for await (const msg of stream) { + finalMsg = finalMsg ? concat(finalMsg, msg) : msg; + } + if (!finalMsg) { + throw new Error("finalMsg is undefined"); + } + expect(finalMsg.content as string).toContain("Dodgers"); + }); + } +); From de56086b11926d1af98a7c1186e2bb9449219c89 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Mon, 30 Dec 2024 09:26:32 -0800 Subject: [PATCH 7/8] Use a type guard --- libs/langchain-google-common/src/utils/common.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts index ea941082e2db..4194f9578b9e 100644 --- a/libs/langchain-google-common/src/utils/common.ts +++ b/libs/langchain-google-common/src/utils/common.ts @@ -62,7 +62,7 @@ function processToolChoice( throw new Error("Object inputs for tool_choice not supported."); } -function isGeminiTool(tool: GoogleAIToolType): boolean { +function isGeminiTool(tool: GoogleAIToolType): tool is GeminiTool { for (const toolAttribute of GeminiToolAttributes) { if (toolAttribute in tool) { return true; @@ -71,7 +71,7 @@ function isGeminiTool(tool: GoogleAIToolType): boolean { return false; } -function isGeminiNonFunctionTool(tool: GoogleAIToolType): boolean { +function isGeminiNonFunctionTool(tool: GoogleAIToolType): tool is GeminiTool { return isGeminiTool(tool) && !("functionDeclaration" in tool); } @@ -80,7 +80,7 @@ export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] { let functionDeclarationsIndex = -1; tools.forEach((tool) => { if (isGeminiNonFunctionTool(tool)) { - geminiTools.push(tool as GeminiTool); + geminiTools.push(tool); } else { if (functionDeclarationsIndex === -1) { geminiTools.push({ From 16594008ac2dd81c0e79cd533a88468bdfd3714e Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Mon, 30 Dec 2024 09:48:20 -0800 Subject: [PATCH 8/8] Fix build issue caused by import in test --- .../src/tests/chat_models.int.test.ts | 146 +++++++++--------- 1 file changed, 70 insertions(+), 76 deletions(-) diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts index 5379533c7181..e66bab6f06ca 100644 --- a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -17,10 +17,7 @@ import { import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; import { ChatPromptValue } from "@langchain/core/prompt_values"; import { - BackedBlobStore, - MediaBlob, MediaManager, - ReadThroughBlobStore, SimpleWebBlobStore, } from "@langchain/google-common/experimental/utils/media_core"; import { @@ -29,9 +26,6 @@ import { GoogleRequestRecorder, } from "@langchain/google-common"; import { BaseCallbackHandler } from "@langchain/core/callbacks/base"; -import { InMemoryStore } from "@langchain/core/stores"; -import { BlobStoreGoogleCloudStorage } from "@langchain/google-gauth"; -import { GoogleCloudStorageUri } from "@langchain/google-common/experimental/media"; import { concat } from "@langchain/core/utils/stream"; import fs from "fs/promises"; import { @@ -535,76 +529,76 @@ describe.each(testGeminiModelNames)( expect(result).toHaveProperty("location"); }); - test("media - fileData", async () => { - class MemStore extends InMemoryStore { - get length() { - return Object.keys(this.store).length; - } - } - const aliasMemory = new MemStore(); - const aliasStore = new BackedBlobStore({ - backingStore: aliasMemory, - defaultFetchOptions: { - actionIfBlobMissing: undefined, - }, - }); - const backingStore = new BlobStoreGoogleCloudStorage({ - uriPrefix: new GoogleCloudStorageUri( - "gs://test-langchainjs/mediatest/" - ), - defaultStoreOptions: { - actionIfInvalid: "prefixPath", - }, - }); - const blobStore = new ReadThroughBlobStore({ - baseStore: aliasStore, - backingStore, - }); - const resolver = new SimpleWebBlobStore(); - const mediaManager = new MediaManager({ - store: blobStore, - resolvers: [resolver], - }); - const model = newChatGoogle({ - apiConfig: { - mediaManager, - }, - }); - - const message: MessageContentComplex[] = [ - { - type: "text", - text: "What is in this image?", - }, - { - type: "media", - fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", - }, - ]; - - const messages: BaseMessage[] = [ - new HumanMessageChunk({ content: message }), - ]; - - try { - const res = await model.invoke(messages); - - console.log(res); - - expect(res).toBeDefined(); - expect(res._getType()).toEqual("ai"); - - const aiMessage = res as AIMessageChunk; - expect(aiMessage.content).toBeDefined(); - - expect(typeof aiMessage.content).toBe("string"); - const text = aiMessage.content as string; - expect(text).toMatch(/LangChain/); - } catch (e) { - console.error(e); - throw e; - } - }); + // test("media - fileData", async () => { + // class MemStore extends InMemoryStore { + // get length() { + // return Object.keys(this.store).length; + // } + // } + // const aliasMemory = new MemStore(); + // const aliasStore = new BackedBlobStore({ + // backingStore: aliasMemory, + // defaultFetchOptions: { + // actionIfBlobMissing: undefined, + // }, + // }); + // const backingStore = new BlobStoreGoogleCloudStorage({ + // uriPrefix: new GoogleCloudStorageUri( + // "gs://test-langchainjs/mediatest/" + // ), + // defaultStoreOptions: { + // actionIfInvalid: "prefixPath", + // }, + // }); + // const blobStore = new ReadThroughBlobStore({ + // baseStore: aliasStore, + // backingStore, + // }); + // const resolver = new SimpleWebBlobStore(); + // const mediaManager = new MediaManager({ + // store: blobStore, + // resolvers: [resolver], + // }); + // const model = newChatGoogle({ + // apiConfig: { + // mediaManager, + // }, + // }); + + // const message: MessageContentComplex[] = [ + // { + // type: "text", + // text: "What is in this image?", + // }, + // { + // type: "media", + // fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + // }, + // ]; + + // const messages: BaseMessage[] = [ + // new HumanMessageChunk({ content: message }), + // ]; + + // try { + // const res = await model.invoke(messages); + + // console.log(res); + + // expect(res).toBeDefined(); + // expect(res._getType()).toEqual("ai"); + + // const aiMessage = res as AIMessageChunk; + // expect(aiMessage.content).toBeDefined(); + + // expect(typeof aiMessage.content).toBe("string"); + // const text = aiMessage.content as string; + // expect(text).toMatch(/LangChain/); + // } catch (e) { + // console.error(e); + // throw e; + // } + // }); test("Stream token count usage_metadata", async () => { const model = newChatGoogle({