From 3db4721b8233c1c49054fd885a5232b566bb09a6 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Sat, 30 Mar 2024 02:41:08 -0700 Subject: [PATCH] Fix remote runnable tracing for stream log and stream events --- langchain-core/src/runnables/remote.ts | 174 ++++++++++-------- .../runnables/tests/runnable_remote.test.ts | 6 +- 2 files changed, 106 insertions(+), 74 deletions(-) diff --git a/langchain-core/src/runnables/remote.ts b/langchain-core/src/runnables/remote.ts index 5ea2dd158afd..306ef1e0e1ab 100644 --- a/langchain-core/src/runnables/remote.ts +++ b/langchain-core/src/runnables/remote.ts @@ -4,10 +4,10 @@ import { Document } from "../documents/index.js"; import { CallbackManagerForChainRun } from "../callbacks/manager.js"; import { ChatPromptValue, StringPromptValue } from "../prompt_values.js"; import { - LogStreamCallbackHandler, RunLogPatch, type LogStreamCallbackHandlerInput, type StreamEvent, + RunLog, } from "../tracers/log_stream.js"; import { AIMessage, @@ -472,20 +472,16 @@ export class RemoteRunnable< ): AsyncGenerator { const [config, kwargs] = this._separateRunnableConfigFromCallOptions(options); - const stream = new LogStreamCallbackHandler({ - ...streamOptions, - autoClose: false, - }); - const { callbacks } = config; - if (callbacks === undefined) { - config.callbacks = [stream]; - } else if (Array.isArray(callbacks)) { - config.callbacks = callbacks.concat([stream]); - } else { - const copiedCallbacks = callbacks.copy(); - copiedCallbacks.inheritableHandlers.push(stream); - config.callbacks = copiedCallbacks; - } + const callbackManager_ = await getCallbackManagerForConfig(options); + const runManager = await callbackManager_?.handleChainStart( + this.toJSON(), + _coerceToDict(input, "input"), + undefined, + undefined, + undefined, + undefined, + options?.runName + ); // The type is in camelCase but the API only accepts snake_case. const camelCaseStreamOptions = { include_names: streamOptions?.includeNames, @@ -495,32 +491,46 @@ export class RemoteRunnable< exclude_types: streamOptions?.excludeTypes, exclude_tags: streamOptions?.excludeTags, }; - const response = await this.post<{ - input: RunInput; - config?: RunnableConfig; - kwargs?: Omit, keyof RunnableConfig>; - diff: false; - }>("/stream_log", { - input, - config: removeCallbacks(config), - kwargs, - ...camelCaseStreamOptions, - diff: false, - }); - const { body, ok } = response; - if (!ok) { - throw new Error(`${response.status} Error: ${await response.text()}`); - } - if (!body) { - throw new Error( - "Could not begin remote stream log. Please check the given URL and try again." - ); - } - const runnableStream = convertEventStreamToIterableReadableDataStream(body); - for await (const log of runnableStream) { - const chunk = revive(JSON.parse(log)); - yield new RunLogPatch({ ops: chunk.ops }); + let runLog; + try { + const response = await this.post<{ + input: RunInput; + config?: RunnableConfig; + kwargs?: Omit, keyof RunnableConfig>; + diff: false; + }>("/stream_log", { + input, + config: removeCallbacks(config), + kwargs, + ...camelCaseStreamOptions, + diff: false, + }); + const { body, ok } = response; + if (!ok) { + throw new Error(`${response.status} Error: ${await response.text()}`); + } + if (!body) { + throw new Error( + "Could not begin remote stream log. Please check the given URL and try again." + ); + } + const runnableStream = + convertEventStreamToIterableReadableDataStream(body); + for await (const log of runnableStream) { + const chunk = revive(JSON.parse(log)); + const logPatch = new RunLogPatch({ ops: chunk.ops }); + yield logPatch; + if (runLog === undefined) { + runLog = RunLog.fromRunLogPatch(logPatch); + } else { + runLog = runLog.concat(logPatch); + } + } + } catch (err) { + await runManager?.handleChainError(err); + throw err; } + await runManager?.handleChainEnd(runLog?.state.final_output); } async *streamEvents( @@ -535,6 +545,16 @@ export class RemoteRunnable< } const [config, kwargs] = this._separateRunnableConfigFromCallOptions(options); + const callbackManager_ = await getCallbackManagerForConfig(options); + const runManager = await callbackManager_?.handleChainStart( + this.toJSON(), + _coerceToDict(input, "input"), + undefined, + undefined, + undefined, + undefined, + options?.runName + ); // The type is in camelCase but the API only accepts snake_case. const camelCaseStreamOptions = { include_names: streamOptions?.includeNames, @@ -544,38 +564,48 @@ export class RemoteRunnable< exclude_types: streamOptions?.excludeTypes, exclude_tags: streamOptions?.excludeTags, }; - const response = await this.post<{ - input: RunInput; - config?: RunnableConfig; - kwargs?: Omit, keyof RunnableConfig>; - diff: false; - }>("/stream_events", { - input, - config: removeCallbacks(config), - kwargs, - ...camelCaseStreamOptions, - diff: false, - }); - const { body, ok } = response; - if (!ok) { - throw new Error(`${response.status} Error: ${await response.text()}`); - } - if (!body) { - throw new Error( - "Could not begin remote stream events. Please check the given URL and try again." - ); - } - const runnableStream = convertEventStreamToIterableReadableDataStream(body); - for await (const log of runnableStream) { - const chunk = revive(JSON.parse(log)); - yield { - event: chunk.event, - name: chunk.name, - run_id: chunk.run_id, - tags: chunk.tags, - metadata: chunk.metadata, - data: chunk.data, - }; + const events = []; + try { + const response = await this.post<{ + input: RunInput; + config?: RunnableConfig; + kwargs?: Omit, keyof RunnableConfig>; + diff: false; + }>("/stream_events", { + input, + config: removeCallbacks(config), + kwargs, + ...camelCaseStreamOptions, + diff: false, + }); + const { body, ok } = response; + if (!ok) { + throw new Error(`${response.status} Error: ${await response.text()}`); + } + if (!body) { + throw new Error( + "Could not begin remote stream events. Please check the given URL and try again." + ); + } + const runnableStream = + convertEventStreamToIterableReadableDataStream(body); + for await (const log of runnableStream) { + const chunk = revive(JSON.parse(log)); + const event = { + event: chunk.event, + name: chunk.name, + run_id: chunk.run_id, + tags: chunk.tags, + metadata: chunk.metadata, + data: chunk.data, + }; + yield event; + events.push(event); + } + } catch (err) { + await runManager?.handleChainError(err); + throw err; } + await runManager?.handleChainEnd(events); } } diff --git a/langchain-core/src/runnables/tests/runnable_remote.test.ts b/langchain-core/src/runnables/tests/runnable_remote.test.ts index 357672587b0c..a53c661a128b 100644 --- a/langchain-core/src/runnables/tests/runnable_remote.test.ts +++ b/langchain-core/src/runnables/tests/runnable_remote.test.ts @@ -219,8 +219,10 @@ describe("RemoteRunnable", () => { test("Streaming in a chain with model output", async () => { const remote = new RemoteRunnable({ url: `${BASE_URL}/b` }); - const prompt = PromptTemplate.fromTemplate(''); - const chunks = await prompt.pipe(remote).stream({ text: "What are the 5 best apples?" }); + const prompt = PromptTemplate.fromTemplate(""); + const chunks = await prompt + .pipe(remote) + .stream({ text: "What are the 5 best apples?" }); let chunkCount = 0; let accumulator: AIMessageChunk | null = null; for await (const chunk of chunks) {