google-common[patch]: Add streaming constructor param (#6165)
* google-common[patch]: Add streaming constructor param

* fix callbacks

* chore: lint files
bracesproul authored Jul 21, 2024
1 parent a6be453 commit 6557b39
Showing 5 changed files with 65 additions and 2 deletions.
24 changes: 22 additions & 2 deletions libs/langchain-google-common/src/chat_models.ts
@@ -27,6 +27,7 @@ import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import { StructuredToolInterface } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
import {
GoogleAIBaseLLMInput,
GoogleAIModelParams,
@@ -241,6 +242,8 @@ export abstract class ChatGoogleBase<AuthOptions>

streamUsage = true;

streaming = false;

protected connection: ChatConnection<AuthOptions>;

protected streamedConnection: ChatConnection<AuthOptions>;
@@ -351,22 +354,38 @@ export abstract class ChatGoogleBase<AuthOptions>
async _generate(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
-  _runManager: CallbackManagerForLLMRun | undefined
+  runManager: CallbackManagerForLLMRun | undefined
): Promise<ChatResult> {
const parameters = this.invocationParams(options);

if (this.streaming) {
const stream = this._streamResponseChunks(messages, options, runManager);
let finalChunk: ChatGenerationChunk | null = null;
for await (const chunk of stream) {
finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}
if (!finalChunk) {
throw new Error("No chunks were returned from the stream.");
}
return {
generations: [finalChunk],
};
}

const response = await this.connection.request(
messages,
parameters,
options
);
const ret = safeResponseToChatResult(response, this.safetyHandler);
await runManager?.handleLLMNewToken(ret.generations[0].text);
return ret;
}

async *_streamResponseChunks(
_messages: BaseMessage[],
options: this["ParsedCallOptions"],
-  _runManager?: CallbackManagerForLLMRun
+  runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
// Make the call as a streaming request
const parameters = this.invocationParams(options);
@@ -408,6 +427,7 @@ export abstract class ChatGoogleBase<AuthOptions>
}),
});
yield chunk;
await runManager?.handleLLMNewToken(chunk.text);
}
}

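Taken together, the changes above mean a model constructed with streaming: true routes invoke() through _streamResponseChunks and fires handleLLMNewToken for each chunk. A minimal usage sketch (the prompt and console logging are illustrative, not part of this commit):

import { ChatVertexAI } from "@langchain/google-vertexai";

async function main() {
  // streaming: true makes _generate consume _streamResponseChunks internally,
  // so token callbacks fire even though we call invoke(), not stream().
  const model = new ChatVertexAI({ streaming: true });

  const result = await model.invoke("What is 1 + 1?", {
    callbacks: [
      {
        handleLLMNewToken: (token) => process.stdout.write(token),
      },
    ],
  });

  // The final result is the concat() of every streamed chunk.
  console.log("\nFinal:", result.content);
}

main();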
11 changes: 11 additions & 0 deletions libs/langchain-google-common/src/tests/chat_models.test.ts
@@ -914,3 +914,14 @@ test("removeAdditionalProperties can remove all instances of additionalProperties
analysisSchemaObj.find((key) => key === "additionalProperties")
).toBeUndefined();
});

test("Can set streaming param", () => {
const modelWithStreamingDefault = new ChatGoogle();

expect(modelWithStreamingDefault.streaming).toBe(false);

const modelWithStreamingTrue = new ChatGoogle({
streaming: true,
});
expect(modelWithStreamingTrue.streaming).toBe(true);
});
6 changes: 6 additions & 0 deletions libs/langchain-google-common/src/types.ts
@@ -104,6 +104,12 @@ export interface GoogleAIModelParams {
* @default "text/plain"
*/
responseMimeType?: GoogleAIResponseMimeType;

/**
* Whether or not to stream.
* @default false
*/
streaming?: boolean;
}

/**
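For callers, the new field is just another optional model param. An illustrative subset of the interface (the full GoogleAIModelParams carries many more fields, and the mime type union shown here is an assumption):

// Illustrative subset only; the two field names match the diff above.
interface GoogleAIModelParamsSubset {
  /** @default "text/plain" */
  responseMimeType?: "text/plain" | "application/json";
  /**
   * Whether or not to stream.
   * @default false
   */
  streaming?: boolean;
}

const opts: GoogleAIModelParamsSubset = { streaming: true };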
1 change: 1 addition & 0 deletions libs/langchain-google-common/src/utils/common.ts
@@ -45,6 +45,7 @@ export function copyAIModelParamsInto(
options?.responseMimeType ??
params?.responseMimeType ??
target?.responseMimeType;
ret.streaming = options?.streaming ?? params?.streaming ?? target?.streaming;

ret.tools = options?.tools;
// Ensure tools are formatted properly for Gemini
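The ?? chain gives per-call options the highest precedence, then constructor params, then whatever is already set on the target. A standalone sketch of that resolution order (resolveStreaming is a made-up helper for illustration):

// Per-call option wins, then the constructor param, then the existing target value.
const resolveStreaming = (
  option?: boolean,
  param?: boolean,
  target?: boolean
): boolean | undefined => option ?? param ?? target;

console.log(resolveStreaming(undefined, true, false)); // true: param beats target
console.log(resolveStreaming(false, true, true)); // false: option beats both
console.log(resolveStreaming(undefined, undefined, true)); // true: falls back to target

Because ?? skips only null and undefined, an explicit streaming: false on a single call still overrides a constructor-level true.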
25 changes: 25 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -297,3 +297,28 @@ test("Invoke token count usage_metadata", async () => {
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});

test("Streaming true constructor param will stream", async () => {
const modelWithStreaming = new ChatVertexAI({
maxOutputTokens: 50,
streaming: true,
});

let totalTokenCount = 0;
let tokensString = "";
const result = await modelWithStreaming.invoke("What is 1 + 1?", {
callbacks: [
{
handleLLMNewToken: (tok) => {
totalTokenCount += 1;
tokensString += tok;
},
},
],
});

expect(result).toBeDefined();
expect(result.content).toBe(tokensString);

expect(totalTokenCount).toBeGreaterThan(1);
});
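The assertion that result.content equals tokensString holds because _generate folds exactly the chunks it reports through handleLLMNewToken. A self-contained sketch of that folding with concat (the chunk texts are fabricated, and the import paths assume the current @langchain/core layout):

import { AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import { concat } from "@langchain/core/utils/stream";

const chunks = ["1 + 1", " is", " 2."].map(
  (text) =>
    new ChatGenerationChunk({
      text,
      message: new AIMessageChunk({ content: text }),
    })
);

// The same reduction _generate performs over the async stream.
let finalChunk: ChatGenerationChunk | null = null;
for (const chunk of chunks) {
  finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}

console.log(finalChunk?.text); // "1 + 1 is 2."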
