Skip to content

Commit

Permalink
fix(google-vertexai): fix bug when not using logprobs (#7515)
Browse files Browse the repository at this point in the history
  • Loading branch information
afirstenberg authored Jan 13, 2025
1 parent de63626 commit 457c8f2
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 5 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ export abstract class ChatGoogleBase<AuthOptions>

stopSequences: string[] = [];

logprobs: boolean = false;
logprobs: boolean;

topLogprobs: number = 0;

Expand Down
75 changes: 74 additions & 1 deletion libs/langchain-google-common/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,80 @@ describe("Mock ChatGoogle - Gemini", () => {
expect(record.opts.data.tools[0]).toHaveProperty("googleSearch");
});

test("7. logprobs", async () => {
test("7. logprobs request true", async () => {
const record: Record<string, any> = {};
const projectId = mockId();
const authOptions: MockClientAuthInfo = {
record,
projectId,
resultFile: "chat-7-mock.json",
};

const model = new ChatGoogle({
authOptions,
modelName: "gemini-1.5-flash-002",
logprobs: true,
topLogprobs: 5,
});
const result = await model.invoke(
"What are some names for a company that makes fancy socks?"
);
expect(result).toBeDefined();
const data = record?.opts?.data;
expect(data).toBeDefined();
expect(data.generationConfig.responseLogprobs).toEqual(true);
expect(data.generationConfig.logprobs).toEqual(5);
});

test("7. logprobs request false", async () => {
const record: Record<string, any> = {};
const projectId = mockId();
const authOptions: MockClientAuthInfo = {
record,
projectId,
resultFile: "chat-7-mock.json",
};

const model = new ChatGoogle({
authOptions,
modelName: "gemini-1.5-flash-002",
logprobs: false,
topLogprobs: 5,
});
const result = await model.invoke(
"What are some names for a company that makes fancy socks?"
);
expect(result).toBeDefined();
const data = record?.opts?.data;
expect(data).toBeDefined();
expect(data.generationConfig.responseLogprobs).toEqual(false);
expect(data.generationConfig.logprobs).not.toBeDefined();
});

test("7. logprobs request not defined", async () => {
const record: Record<string, any> = {};
const projectId = mockId();
const authOptions: MockClientAuthInfo = {
record,
projectId,
resultFile: "chat-7-mock.json",
};

const model = new ChatGoogle({
authOptions,
modelName: "gemini-1.5-flash-002",
});
const result = await model.invoke(
"What are some names for a company that makes fancy socks?"
);
expect(result).toBeDefined();
const data = record?.opts?.data;
expect(data).toBeDefined();
expect(data.generationConfig.responseLogprobs).not.toBeDefined();
expect(data.generationConfig.logprobs).not.toBeDefined();
});

test("7. logprobs response", async () => {
const record: Record<string, any> = {};
const projectId = mockId();
const authOptions: MockClientAuthInfo = {
Expand Down
17 changes: 14 additions & 3 deletions libs/langchain-google-common/src/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
function formatGenerationConfig(
parameters: GoogleAIModelRequestParams
): GeminiGenerationConfig {
return {
const ret: GeminiGenerationConfig = {
temperature: parameters.temperature,
topK: parameters.topK,
topP: parameters.topP,
Expand All @@ -1089,9 +1089,20 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
maxOutputTokens: parameters.maxOutputTokens,
stopSequences: parameters.stopSequences,
responseMimeType: parameters.responseMimeType,
responseLogprobs: parameters.logprobs,
logprobs: parameters.topLogprobs,
};

// Add the logprobs if explicitly set
if (typeof parameters.logprobs !== "undefined") {
ret.responseLogprobs = parameters.logprobs;
if (
parameters.logprobs &&
typeof parameters.topLogprobs !== "undefined"
) {
ret.logprobs = parameters.topLogprobs;
}
}

return ret;
}

function formatSafetySettings(
Expand Down

0 comments on commit 457c8f2

Please sign in to comment.