From 457c8f2ebbf33447b704fa6808843b3a40847994 Mon Sep 17 00:00:00 2001 From: Allen Firstenberg Date: Mon, 13 Jan 2025 17:23:05 -0500 Subject: [PATCH] fix(google-vertexai): fix bug when not using logprobs (#7515) --- .../src/chat_models.ts | 2 +- .../src/tests/chat_models.test.ts | 75 ++++++++++++++++++- .../src/utils/gemini.ts | 17 ++++- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index a345497c0d8f..75a83d50bf06 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -195,7 +195,7 @@ export abstract class ChatGoogleBase stopSequences: string[] = []; - logprobs: boolean = false; + logprobs: boolean; topLogprobs: number = 0; diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index 406e399c12b0..9494f55850eb 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -1251,7 +1251,80 @@ describe("Mock ChatGoogle - Gemini", () => { expect(record.opts.data.tools[0]).toHaveProperty("googleSearch"); }); - test("7. logprobs", async () => { + test("7. logprobs request true", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-7-mock.json", + }; + + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-flash-002", + logprobs: true, + topLogprobs: 5, + }); + const result = await model.invoke( + "What are some names for a company that makes fancy socks?" + ); + expect(result).toBeDefined(); + const data = record?.opts?.data; + expect(data).toBeDefined(); + expect(data.generationConfig.responseLogprobs).toEqual(true); + expect(data.generationConfig.logprobs).toEqual(5); + }); + + test("7. logprobs request false", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-7-mock.json", + }; + + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-flash-002", + logprobs: false, + topLogprobs: 5, + }); + const result = await model.invoke( + "What are some names for a company that makes fancy socks?" + ); + expect(result).toBeDefined(); + const data = record?.opts?.data; + expect(data).toBeDefined(); + expect(data.generationConfig.responseLogprobs).toEqual(false); + expect(data.generationConfig.logprobs).not.toBeDefined(); + }); + + test("7. logprobs request not defined", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-7-mock.json", + }; + + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-flash-002", + }); + const result = await model.invoke( + "What are some names for a company that makes fancy socks?" + ); + expect(result).toBeDefined(); + const data = record?.opts?.data; + expect(data).toBeDefined(); + expect(data.generationConfig.responseLogprobs).not.toBeDefined(); + expect(data.generationConfig.logprobs).not.toBeDefined(); + }); + + test("7. logprobs response", async () => { const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 3408ab8e26b6..213160a43b10 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -1080,7 +1080,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { function formatGenerationConfig( parameters: GoogleAIModelRequestParams ): GeminiGenerationConfig { - return { + const ret: GeminiGenerationConfig = { temperature: parameters.temperature, topK: parameters.topK, topP: parameters.topP, @@ -1089,9 +1089,20 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { maxOutputTokens: parameters.maxOutputTokens, stopSequences: parameters.stopSequences, responseMimeType: parameters.responseMimeType, - responseLogprobs: parameters.logprobs, - logprobs: parameters.topLogprobs, }; + + // Add the logprobs if explicitly set + if (typeof parameters.logprobs !== "undefined") { + ret.responseLogprobs = parameters.logprobs; + if ( + parameters.logprobs && + typeof parameters.topLogprobs !== "undefined" + ) { + ret.logprobs = parameters.topLogprobs; + } + } + + return ret; } function formatSafetySettings(