fix callbacks
bracesproul committed Jul 21, 2024
1 parent 77317ed commit 4406c98
Showing 2 changed files with 44 additions and 2 deletions.
22 changes: 20 additions & 2 deletions libs/langchain-google-common/src/chat_models.ts
@@ -62,6 +62,7 @@ import {
   jsonSchemaToGeminiParameters,
   zodToGeminiParameters,
 } from "./utils/zod_to_gemini_parameters.js";
+import { concat } from "@langchain/core/utils/stream";
 
 class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
   BaseMessage[],
@@ -353,22 +354,38 @@ export abstract class ChatGoogleBase<AuthOptions>
   async _generate(
     messages: BaseMessage[],
     options: this["ParsedCallOptions"],
-    _runManager: CallbackManagerForLLMRun | undefined
+    runManager: CallbackManagerForLLMRun | undefined
   ): Promise<ChatResult> {
     const parameters = this.invocationParams(options);
+
+    if (this.streaming) {
+      const stream = this._streamResponseChunks(messages, options, runManager);
+      let finalChunk: ChatGenerationChunk | null = null;
+      for await (const chunk of stream) {
+        finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
+      }
+      if (!finalChunk) {
+        throw new Error("No chunks were returned from the stream.");
+      }
+      return {
+        generations: [finalChunk]
+      };
+    }
+
     const response = await this.connection.request(
       messages,
       parameters,
       options
     );
     const ret = safeResponseToChatResult(response, this.safetyHandler);
+    await runManager?.handleLLMNewToken(ret.generations[0].text);
     return ret;
   }
 
   async *_streamResponseChunks(
     _messages: BaseMessage[],
     options: this["ParsedCallOptions"],
-    _runManager?: CallbackManagerForLLMRun
+    runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
     // Make the call as a streaming request
     const parameters = this.invocationParams(options);
@@ -410,6 +427,7 @@ export abstract class ChatGoogleBase<AuthOptions>
         }),
       });
       yield chunk;
+      await runManager?.handleLLMNewToken(chunk.text);
     }
   }

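The upshot of the chat_models.ts change: _generate now actually uses its run manager. When streaming is true it drains _streamResponseChunks and merges the chunks with concat; in the non-streaming path it reports the final text through handleLLMNewToken. Below is a minimal standalone sketch of that aggregation pattern, assuming only @langchain/core; the collectChunks helper is an illustrative name, not part of this commit.

import { concat } from "@langchain/core/utils/stream";
import type { ChatGenerationChunk } from "@langchain/core/outputs";

// Drain an async generator of chunks into one merged chunk, mirroring the
// streaming branch added to _generate above. (Illustrative helper, not from
// this commit.)
async function collectChunks(
  stream: AsyncGenerator<ChatGenerationChunk>
): Promise<ChatGenerationChunk> {
  let finalChunk: ChatGenerationChunk | null = null;
  for await (const chunk of stream) {
    // concat merges each chunk's text, message content, and generationInfo
    finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
  }
  if (!finalChunk) {
    throw new Error("No chunks were returned from the stream.");
  }
  return finalChunk;
}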
24 changes: 24 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -297,3 +297,27 @@ test("Invoke token count usage_metadata", async () => {
     res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
   );
 });
+
+test("Streaming true constructor param will stream", async () => {
+  const modelWithStreaming = new ChatVertexAI({
+    maxOutputTokens: 50,
+    streaming: true,
+  });
+
+  let totalTokenCount = 0;
+  let tokensString = "";
+  const result = await modelWithStreaming.invoke("What is 1 + 1?", {
+    callbacks: [{
+      handleLLMNewToken: (tok) => {
+        totalTokenCount += 1;
+        tokensString += tok;
+      }
+    }]
+  });
+
+  expect(result).toBeDefined();
+  console.log("result.content", result.content);
+  expect(result.content).toBe(tokensString);
+
+  expect(totalTokenCount).toBeGreaterThan(1);
+});
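Seen from the caller's side, the behavior this test pins down is that streaming: true makes per-token callbacks fire during a plain invoke. A hedged sketch of that usage follows; the prompt and option values are illustrative, and it assumes an ESM module with top-level await.

import { ChatVertexAI } from "@langchain/google-vertexai";

const model = new ChatVertexAI({ streaming: true, maxOutputTokens: 50 });

const tokens: string[] = [];
const result = await model.invoke("What is 1 + 1?", {
  callbacks: [
    {
      // With this fix, fires once per streamed chunk even though the
      // caller used invoke() rather than stream().
      handleLLMNewToken: (token: string) => {
        tokens.push(token);
      },
    },
  ],
});

// The concatenated callback tokens match the final message content.
console.log(tokens.join("") === result.content);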
