core[minor]: Make LLMs and chat models always stream when invoked within streamEvents #5604

Merged May 30, 2024 (3 commits). Changes from all commits are shown below.
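In practice the change matters when a chat model or LLM is invoke()d, rather than stream()ed, somewhere inside a larger runnable that is consumed through streamEvents() or streamLog(). A minimal usage sketch of the new behavior follows; the model class, prompt, and streamEvents version flag are illustrative assumptions, not part of this PR:

import { RunnableLambda } from "@langchain/core/runnables";
import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI({ model: "gpt-3.5-turbo" });

// The model is only ever invoked here; .stream() is never called directly.
const chain = RunnableLambda.from(async (topic: string, config) =>
  model.invoke(`Tell me a joke about ${topic}`, config)
);

// Because a streamEvents tracer is attached to the run, the invoke() above is
// now executed via _streamResponseChunks under the hood, so per-chunk
// on_chat_model_stream events are emitted instead of only the final message.
for await (const event of chain.streamEvents("bears", { version: "v1" })) {
  if (event.event === "on_chat_model_stream") {
    console.log(event.data.chunk);
  }
}

Models that do not override _streamResponseChunks, and batched calls with more than one prompt, fall back to the existing _generate path unchanged.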
119 changes: 81 additions & 38 deletions langchain-core/src/language_models/chat_models.ts
@@ -30,6 +30,9 @@ import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";
import { StructuredToolInterface } from "../tools.js";
import { Runnable } from "../runnables/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";

/**
* Represents a serialized chat model.
@@ -306,48 +309,88 @@ export abstract class BaseChatModel<
undefined,
handledOptions.runName
);
const generations: ChatGeneration[][] = [];
const llmOutputs: LLMResult["llmOutput"][] = [];
// Even if stream is not explicitly called, check if model is implicitly
// called from streamEvents() or streamLog() to get all streamed events.
// Bail out if _streamResponseChunks not overridden
const hasStreamingHandler = !!runManagers?.[0].handlers.find((handler) => {
return isStreamEventsHandler(handler) || isLogStreamHandler(handler);
});
if (
hasStreamingHandler &&
baseMessages.length === 1 &&
this._streamResponseChunks !==
BaseChatModel.prototype._streamResponseChunks
) {
try {
const stream = await this._streamResponseChunks(
baseMessages[0],
parsedOptions,
runManagers?.[0]
);
let aggregated;
for await (const chunk of stream) {
if (aggregated === undefined) {
aggregated = chunk;
} else {
aggregated = concat(aggregated, chunk);
}
}
if (aggregated === undefined) {
throw new Error("Received empty response from chat model call.");
}
generations.push([aggregated]);
await runManagers?.[0].handleLLMEnd({
generations,
llmOutput: {},
});
} catch (e) {
await runManagers?.[0].handleLLMError(e);
throw e;
}
} else {
// generate results
const results = await Promise.allSettled(
baseMessages.map((messageList, i) =>
this._generate(
messageList,
{ ...parsedOptions, promptIndex: i },
runManagers?.[i]
)
)
);
// handle results
await Promise.all(
results.map(async (pResult, i) => {
if (pResult.status === "fulfilled") {
const result = pResult.value;
for (const generation of result.generations) {
generation.message.response_metadata = {
...generation.generationInfo,
...generation.message.response_metadata,
};
}
if (result.generations.length === 1) {
result.generations[0].message.response_metadata = {
...result.llmOutput,
...result.generations[0].message.response_metadata,
};
}
generations[i] = result.generations;
llmOutputs[i] = result.llmOutput;
return runManagers?.[i]?.handleLLMEnd({
generations: [result.generations],
llmOutput: result.llmOutput,
});
} else {
// status === "rejected"
await runManagers?.[i]?.handleLLMError(pResult.reason);
return Promise.reject(pResult.reason);
}
})
);
}
// create combined output
const output: LLMResult = {
generations,
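The aggregation loop above leans on the concat helper imported from ../utils/stream.js, which defers to each chunk's own .concat() method. A small sketch of that merging behavior, using the public @langchain/core entrypoints (assumed here as the published equivalents of the relative imports in the diff):

import { AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import { concat } from "@langchain/core/utils/stream";

// _streamResponseChunks yields ChatGenerationChunk objects; concat() appends
// their text/content and merges generation metadata, producing the single
// aggregated generation that is handed to handleLLMEnd.
const first = new ChatGenerationChunk({
  text: "Hello, ",
  message: new AIMessageChunk({ content: "Hello, " }),
});
const second = new ChatGenerationChunk({
  text: "world!",
  message: new AIMessageChunk({ content: "world!" }),
});

const aggregated = concat(first, second);
console.log(aggregated.text); // "Hello, world!"
console.log(aggregated.message.content); // "Hello, world!"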
66 changes: 53 additions & 13 deletions langchain-core/src/language_models/llms.ts
@@ -24,6 +24,9 @@ import {
} from "./base.js";
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";

export type SerializedLLM = {
_model: string;
@@ -276,23 +279,60 @@ export abstract class BaseLLM<
undefined,
handledOptions?.runName
);
// Even if stream is not explicitly called, check if model is implicitly
// called from streamEvents() or streamLog() to get all streamed events.
// Bail out if _streamResponseChunks not overridden
const hasStreamingHandler = !!runManagers?.[0].handlers.find((handler) => {
return isStreamEventsHandler(handler) || isLogStreamHandler(handler);
});
let output: LLMResult;
if (
hasStreamingHandler &&
prompts.length === 1 &&
this._streamResponseChunks !== BaseLLM.prototype._streamResponseChunks
) {
try {
const stream = await this._streamResponseChunks(
prompts[0],
parsedOptions,
runManagers?.[0]
);
let aggregated;
for await (const chunk of stream) {
if (aggregated === undefined) {
aggregated = chunk;
} else {
aggregated = concat(aggregated, chunk);
}
}
if (aggregated === undefined) {
throw new Error("Received empty response from chat model call.");
}
output = { generations: [[aggregated]], llmOutput: {} };
await runManagers?.[0].handleLLMEnd(output);
} catch (e) {
await runManagers?.[0].handleLLMError(e);
throw e;
}
} else {
try {
output = await this._generate(prompts, parsedOptions, runManagers?.[0]);
} catch (err) {
await Promise.all(
(runManagers ?? []).map((runManager) =>
runManager?.handleLLMError(err)
)
);
throw err;
}

const flattenedOutputs: LLMResult[] = this._flattenLLMResult(output);
await Promise.all(
(runManagers ?? []).map((runManager, i) =>
runManager?.handleLLMEnd(flattenedOutputs[i])
)
);
}
const runIds = runManagers?.map((manager) => manager.runId) || undefined;
// This defines RUN_KEY as a non-enumerable property on the output object
// so that it is not serialized when the output is stringified, and so that
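Both files gate the implicit-streaming path on a prototype comparison such as this._streamResponseChunks !== BaseLLM.prototype._streamResponseChunks, so only subclasses that actually implement streaming opt in. A self-contained sketch of why that check works, with hypothetical class names:

// A subclass that defines its own method gets a distinct function object on
// its prototype, so the strict inequality holds exactly when streaming has
// been implemented; otherwise the model keeps using the _generate path.
class BaseModel {
  async *streamChunks() {
    throw new Error("Not implemented.");
  }
}

class StreamingModel extends BaseModel {
  async *streamChunks() {
    yield "chunk";
  }
}

class NonStreamingModel extends BaseModel {}

console.log(new StreamingModel().streamChunks !== BaseModel.prototype.streamChunks); // true
console.log(new NonStreamingModel().streamChunks !== BaseModel.prototype.streamChunks); // false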
11 changes: 2 additions & 9 deletions langchain-core/src/runnables/base.ts
@@ -12,12 +12,14 @@ import {
LogStreamCallbackHandlerInput,
RunLog,
RunLogPatch,
isLogStreamHandler,
} from "../tracers/log_stream.js";
import {
EventStreamCallbackHandler,
EventStreamCallbackHandlerInput,
StreamEvent,
StreamEventData,
isStreamEventsHandler,
} from "../tracers/event_stream.js";
import { Serializable } from "../load/serializable.js";
import {
@@ -38,7 +40,6 @@ import {
import { AsyncCaller } from "../utils/async_caller.js";
import { Run } from "../tracers/base.js";
import { RootListenersTracer } from "../tracers/root_listener.js";
import { BaseCallbackHandler } from "../callbacks/base.js";
import { _RootEventFilter, isRunnableInterface } from "./utils.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { Graph } from "./graph.js";
@@ -503,10 +504,6 @@ export abstract class Runnable<
delete config.runId;
runManager = pipe.setup;

const isStreamEventsHandler = (
handler: BaseCallbackHandler
): handler is EventStreamCallbackHandler =>
handler.name === "event_stream_tracer";
const streamEventsHandler = runManager?.handlers.find(
isStreamEventsHandler
);
@@ -518,10 +515,6 @@
);
}

const isLogStreamHandler = (
handler: BaseCallbackHandler
): handler is LogStreamCallbackHandler =>
handler.name === "log_stream_tracer";
const streamLogHandler = runManager?.handlers.find(isLogStreamHandler);
if (streamLogHandler !== undefined && runManager !== undefined) {
iterator = streamLogHandler.tapOutputIterable(