From e30779b7a784c4ef3d465f54c5b01af9dccdd355 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 16 Dec 2024 11:13:01 -0800 Subject: [PATCH] feat(core): Generalize streaming usage for language models based on passed callback handlers (#7378) --- langchain-core/src/callbacks/base.ts | 13 +++++++++++++ langchain-core/src/language_models/chat_models.ts | 9 ++++----- langchain-core/src/language_models/llms.ts | 9 ++++----- langchain-core/src/tracers/event_stream.ts | 8 +++++++- langchain-core/src/tracers/log_stream.ts | 8 +++++++- 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/langchain-core/src/callbacks/base.ts b/langchain-core/src/callbacks/base.ts index 42ecd9820768..923dde568093 100644 --- a/langchain-core/src/callbacks/base.ts +++ b/langchain-core/src/callbacks/base.ts @@ -285,6 +285,19 @@ abstract class BaseCallbackHandlerMethodsClass { */ export type CallbackHandlerMethods = BaseCallbackHandlerMethodsClass; +/** + * Interface for handlers that can indicate a preference for streaming responses. + * When implemented, this allows the handler to signal whether it prefers to receive + * streaming responses from language models rather than complete responses. + */ +export interface CallbackHandlerPrefersStreaming { + readonly lc_prefer_streaming: boolean; +} + +export function callbackHandlerPrefersStreaming(x: BaseCallbackHandler) { + return "lc_prefer_streaming" in x && x.lc_prefer_streaming; +} + /** * Abstract base class for creating callback handlers in the LangChain * framework. It provides a set of optional methods that can be overridden diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts index 0878d67def8e..e80e5cf90886 100644 --- a/langchain-core/src/language_models/chat_models.ts +++ b/langchain-core/src/language_models/chat_models.ts @@ -44,11 +44,10 @@ import { RunnableSequence, RunnableToolLike, } from "../runnables/base.js"; -import { isStreamEventsHandler } from "../tracers/event_stream.js"; -import { isLogStreamHandler } from "../tracers/log_stream.js"; import { concat } from "../utils/stream.js"; import { RunnablePassthrough } from "../runnables/passthrough.js"; import { isZodSchema } from "../utils/types/is_zod_schema.js"; +import { callbackHandlerPrefersStreaming } from "../callbacks/base.js"; // eslint-disable-next-line @typescript-eslint/no-explicit-any export type ToolChoice = string | Record | "auto" | "any"; @@ -370,9 +369,9 @@ export abstract class BaseChatModel< // Even if stream is not explicitly called, check if model is implicitly // called from streamEvents() or streamLog() to get all streamed events. // Bail out if _streamResponseChunks not overridden - const hasStreamingHandler = !!runManagers?.[0].handlers.find((handler) => { - return isStreamEventsHandler(handler) || isLogStreamHandler(handler); - }); + const hasStreamingHandler = !!runManagers?.[0].handlers.find( + callbackHandlerPrefersStreaming + ); if ( hasStreamingHandler && baseMessages.length === 1 && diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts index 3aeb2a879bdc..ce75a52479be 100644 --- a/langchain-core/src/language_models/llms.ts +++ b/langchain-core/src/language_models/llms.ts @@ -24,9 +24,8 @@ import { } from "./base.js"; import type { RunnableConfig } from "../runnables/config.js"; import type { BaseCache } from "../caches/base.js"; -import { isStreamEventsHandler } from "../tracers/event_stream.js"; -import { isLogStreamHandler } from "../tracers/log_stream.js"; import { concat } from "../utils/stream.js"; +import { callbackHandlerPrefersStreaming } from "../callbacks/base.js"; export type SerializedLLM = { _model: string; @@ -270,9 +269,9 @@ export abstract class BaseLLM< // Even if stream is not explicitly called, check if model is implicitly // called from streamEvents() or streamLog() to get all streamed events. // Bail out if _streamResponseChunks not overridden - const hasStreamingHandler = !!runManagers?.[0].handlers.find((handler) => { - return isStreamEventsHandler(handler) || isLogStreamHandler(handler); - }); + const hasStreamingHandler = !!runManagers?.[0].handlers.find( + callbackHandlerPrefersStreaming + ); let output: LLMResult; if ( hasStreamingHandler && diff --git a/langchain-core/src/tracers/event_stream.ts b/langchain-core/src/tracers/event_stream.ts index 3972e7ce9b4b..54a543a0069a 100644 --- a/langchain-core/src/tracers/event_stream.ts +++ b/langchain-core/src/tracers/event_stream.ts @@ -2,6 +2,7 @@ import { BaseTracer, type Run } from "./base.js"; import { BaseCallbackHandler, BaseCallbackHandlerInput, + CallbackHandlerPrefersStreaming, } from "../callbacks/base.js"; import { IterableReadableStream } from "../utils/stream.js"; import { AIMessageChunk } from "../messages/ai.js"; @@ -145,7 +146,10 @@ export const isStreamEventsHandler = ( * handler that logs the execution of runs and emits `RunLog` instances to a * `RunLogStream`. */ -export class EventStreamCallbackHandler extends BaseTracer { +export class EventStreamCallbackHandler + extends BaseTracer + implements CallbackHandlerPrefersStreaming +{ protected autoClose = true; protected includeNames?: string[]; @@ -172,6 +176,8 @@ export class EventStreamCallbackHandler extends BaseTracer { name = "event_stream_tracer"; + lc_prefer_streaming = true; + constructor(fields?: EventStreamCallbackHandlerInput) { super({ _awaitHandler: true, ...fields }); this.autoClose = fields?.autoClose ?? true; diff --git a/langchain-core/src/tracers/log_stream.ts b/langchain-core/src/tracers/log_stream.ts index 13b97349b04b..f13647f0dc4c 100644 --- a/langchain-core/src/tracers/log_stream.ts +++ b/langchain-core/src/tracers/log_stream.ts @@ -6,6 +6,7 @@ import { BaseTracer, type Run } from "./base.js"; import { BaseCallbackHandler, BaseCallbackHandlerInput, + CallbackHandlerPrefersStreaming, HandleLLMNewTokenCallbackFields, } from "../callbacks/base.js"; import { IterableReadableStream } from "../utils/stream.js"; @@ -210,7 +211,10 @@ function isChatGenerationChunk( * handler that logs the execution of runs and emits `RunLog` instances to a * `RunLogStream`. */ -export class LogStreamCallbackHandler extends BaseTracer { +export class LogStreamCallbackHandler + extends BaseTracer + implements CallbackHandlerPrefersStreaming +{ protected autoClose = true; protected includeNames?: string[]; @@ -241,6 +245,8 @@ export class LogStreamCallbackHandler extends BaseTracer { name = "log_stream_tracer"; + lc_prefer_streaming = true; + constructor(fields?: LogStreamCallbackHandlerInput) { super({ _awaitHandler: true, ...fields }); this.autoClose = fields?.autoClose ?? true;