diff --git a/packages/node/src/api.ts b/packages/node/src/api.ts index fea7922d5..ebb70d2d0 100644 --- a/packages/node/src/api.ts +++ b/packages/node/src/api.ts @@ -22,7 +22,12 @@ import type { PascalCase } from 'type-fest'; import { NodeNativeApi } from './native/NodeNativeApi.js'; import { mapValues } from 'lodash-es'; import * as events from 'node:events'; -import { getProcessorEventStream, getProcessorEvents, type RivetEventStreamFilterSpec } from './streaming.js'; +import { + getProcessorSSEStream, + getProcessorEvents, + type RivetEventStreamFilterSpec, + getSingleNodeStream, +} from './streaming.js'; export async function loadProjectFromFile(path: string): Promise { const content = await readFile(path, { encoding: 'utf8' }); @@ -200,7 +205,8 @@ export function createProcessor(project: Project, options: RunGraphOptions) { inputs: resolvedInputs, contextValues: resolvedContextValues, getEvents: (spec: RivetEventStreamFilterSpec) => getProcessorEvents(processor, spec), - getEventStream: (spec: RivetEventStreamFilterSpec) => getProcessorEventStream(processor, spec), + getSSEStream: (spec: RivetEventStreamFilterSpec) => getProcessorSSEStream(processor, spec), + streamNode: (nodeIdOrTitle: string) => getSingleNodeStream(processor, nodeIdOrTitle), async run() { const outputs = await processor.processGraph( { diff --git a/packages/node/src/index.ts b/packages/node/src/index.ts index 200e11959..e9c761be7 100644 --- a/packages/node/src/index.ts +++ b/packages/node/src/index.ts @@ -2,6 +2,7 @@ export * from '@ironclad/rivet-core'; export * from './native/NodeNativeApi.js'; export * from './api.js'; +export * from './streaming.js'; export * from './debugger.js'; export * from './native/NodeDatasetProvider.js'; export * from './native/DebuggerDatasetProvider.js'; diff --git a/packages/node/src/streaming.ts b/packages/node/src/streaming.ts index 3229b36c7..e77898f51 100644 --- a/packages/node/src/streaming.ts +++ b/packages/node/src/streaming.ts @@ -67,28 +67,27 @@ export async function* getProcessorEvents( spec: RivetEventStreamFilterSpec, ): AsyncGenerator { const previousIndexes = new Map(); + for await (const event of processor.events()) { if (event.type === 'partialOutput') { if ( spec.partialOutputs === true || - !spec.partialOutputs?.includes(event.node.id) || - !spec.partialOutputs?.includes(event.node.title) + spec.partialOutputs?.includes(event.node.id) || + spec.partialOutputs?.includes(event.node.title) ) { - return; - } - - const currentOutput = coerceType(event.outputs['response' as PortId], 'string'); + const currentOutput = coerceType(event.outputs['response' as PortId], 'string'); - const delta = currentOutput.slice(previousIndexes.get(event.node.id) ?? 0); + const delta = currentOutput.slice(previousIndexes.get(event.node.id) ?? 0); - yield { - type: 'partialOutput', - nodeId: event.node.id, - nodeTitle: event.node.title, - delta, - }; + yield { + type: 'partialOutput', + nodeId: event.node.id, + nodeTitle: event.node.title, + delta, + }; - previousIndexes.set(event.node.id, currentOutput.length); + previousIndexes.set(event.node.id, currentOutput.length); + } } else if (event.type === 'done') { if (spec.done) { yield { @@ -140,7 +139,7 @@ export async function* getProcessorEvents( * Includes configuration for what events to send to the client, for example you can stream the partial output deltas * for specific nodes, and/or the graph output when done. */ -export function getProcessorEventStream( +export function getProcessorSSEStream( processor: GraphProcessor, /** The spec for what you're streaming to the client */ @@ -170,3 +169,29 @@ export function getProcessorEventStream( }, }); } + +export function getSingleNodeStream(processor: GraphProcessor, nodeIdOrTitle: string) { + return new ReadableStream({ + async start(controller) { + try { + for await (const event of getProcessorEvents(processor, { + partialOutputs: [nodeIdOrTitle], + nodeFinish: [nodeIdOrTitle], + })) { + if (event.type === 'partialOutput' && (event.nodeId === nodeIdOrTitle || event.nodeTitle === nodeIdOrTitle)) { + controller.enqueue(`data: ${JSON.stringify(event.delta)}\n\n`); + } else if ( + event.type === 'nodeFinish' && + (event.nodeId === nodeIdOrTitle || event.nodeTitle === nodeIdOrTitle) + ) { + controller.close(); + } + } + + controller.close(); + } catch (err) { + controller.error(err); + } + }, + }); +} diff --git a/packages/node/test/api.test.ts b/packages/node/test/api.test.ts index ff89be30e..b53147d50 100644 --- a/packages/node/test/api.test.ts +++ b/packages/node/test/api.test.ts @@ -57,7 +57,7 @@ describe('api', () => { processor.run(); const reader = processor - .getEventStream({ + .getSSEStream({ nodeFinish: true, }) .getReader();