From 1a4e0ca4c5a34f5498832857e7c09d8289fe04be Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 19 Dec 2024 17:23:08 -0800 Subject: [PATCH] fix(core): Fix runnable with fallbacks stream events bug (#7411) --- langchain-core/src/runnables/base.ts | 49 +++++++++++-------- .../tests/runnable_with_fallbacks.test.ts | 37 ++++++++++++++ 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index d2ba10b6f6bd..09c8c7162ee9 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -2840,7 +2840,7 @@ export class RunnableWithFallbacks extends Runnable< options?: Partial ): Promise { const config = ensureConfig(options); - const callbackManager_ = await getCallbackManagerForConfig(options); + const callbackManager_ = await getCallbackManagerForConfig(config); const { runId, ...otherConfigFields } = config; const runManager = await callbackManager_?.handleChainStart( this.toJSON(), @@ -2851,27 +2851,33 @@ export class RunnableWithFallbacks extends Runnable< undefined, otherConfigFields?.runName ); - let firstError; - for (const runnable of this.runnables()) { - config?.signal?.throwIfAborted(); - try { - const output = await runnable.invoke( - input, - patchConfig(otherConfigFields, { callbacks: runManager?.getChild() }) - ); - await runManager?.handleChainEnd(_coerceToDict(output, "output")); - return output; - } catch (e) { + const childConfig = patchConfig(otherConfigFields, { + callbacks: runManager?.getChild(), + }); + const res = await AsyncLocalStorageProviderSingleton.runWithConfig( + childConfig, + async () => { + let firstError; + for (const runnable of this.runnables()) { + config?.signal?.throwIfAborted(); + try { + const output = await runnable.invoke(input, childConfig); + await runManager?.handleChainEnd(_coerceToDict(output, "output")); + return output; + } catch (e) { + if (firstError === undefined) { + firstError = e; + } + } + } if (firstError === undefined) { - firstError = e; + throw new Error("No error stored at end of fallback."); } + await runManager?.handleChainError(firstError); + throw firstError; } - } - if (firstError === undefined) { - throw new Error("No error stored at end of fallback."); - } - await runManager?.handleChainError(firstError); - throw firstError; + ); + return res; } async *_streamIterator( @@ -2879,7 +2885,7 @@ export class RunnableWithFallbacks extends Runnable< options?: Partial | undefined ): AsyncGenerator { const config = ensureConfig(options); - const callbackManager_ = await getCallbackManagerForConfig(options); + const callbackManager_ = await getCallbackManagerForConfig(config); const { runId, ...otherConfigFields } = config; const runManager = await callbackManager_?.handleChainStart( this.toJSON(), @@ -2898,7 +2904,8 @@ export class RunnableWithFallbacks extends Runnable< callbacks: runManager?.getChild(), }); try { - stream = await runnable.stream(input, childConfig); + const originalStream = await runnable.stream(input, childConfig); + stream = consumeAsyncIterableInContext(childConfig, originalStream); break; } catch (e) { if (firstError === undefined) { diff --git a/langchain-core/src/runnables/tests/runnable_with_fallbacks.test.ts b/langchain-core/src/runnables/tests/runnable_with_fallbacks.test.ts index 652116fd70ee..7cb5c5fe04c5 100644 --- a/langchain-core/src/runnables/tests/runnable_with_fallbacks.test.ts +++ b/langchain-core/src/runnables/tests/runnable_with_fallbacks.test.ts @@ -1,7 +1,11 @@ /* eslint-disable no-promise-executor-return */ /* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; +import { AsyncLocalStorage } from "node:async_hooks"; import { FakeLLM, FakeStreamingLLM } from "../../utils/testing/index.js"; +import { RunnableLambda } from "../base.js"; +import { AsyncLocalStorageProviderSingleton } from "../../singletons/index.js"; test("RunnableWithFallbacks", async () => { const llm = new FakeLLM({ @@ -55,3 +59,36 @@ test("RunnableWithFallbacks stream", async () => { expect(chunks.length).toBeGreaterThan(1); expect(chunks.join("")).toEqual("What up"); }); + +test("RunnableWithFallbacks stream events with local storage and callbacks added via env vars", async () => { + process.env.LANGCHAIN_VERBOSE = "true"; + AsyncLocalStorageProviderSingleton.initializeGlobalInstance( + new AsyncLocalStorage() + ); + const llm = new FakeStreamingLLM({ + thrownErrorString: "Bad error!", + }); + const llmWithFallbacks = llm.withFallbacks({ + fallbacks: [new FakeStreamingLLM({})], + }); + const runnable = RunnableLambda.from(async (input: any) => { + const res = await llmWithFallbacks.invoke(input); + const stream = await llmWithFallbacks.stream(input); + for await (const _ of stream) { + void _; + } + return res; + }); + const stream = await runnable.streamEvents("hi", { + version: "v2", + }); + const chunks = []; + for await (const chunk of stream) { + if (chunk.event === "on_llm_stream") { + chunks.push(chunk); + } + } + expect(chunks.length).toBeGreaterThan(1); + console.log(JSON.stringify(chunks, null, 2)); + expect(chunks.map((chunk) => chunk.data.chunk.text).join("")).toEqual("hihi"); +});