From 329f0291bb0727fe30bc5bcfc9108cae0f48810d Mon Sep 17 00:00:00 2001 From: Andy Brenneke Date: Fri, 20 Oct 2023 12:55:30 -0700 Subject: [PATCH] Add simplified graph event streaming: an async generator, and an SSE implementation --- packages/core/.eslintrc.cjs | 2 +- packages/core/package.json | 2 +- packages/core/src/model/GraphProcessor.ts | 8 +- .../core/test/model/GraphProcessor.test.ts | 42 +++++ packages/core/test/test-graphs.rivet-project | 35 ++++ packages/core/test/testUtils.ts | 41 +++++ packages/core/tsconfig.eslint.json | 4 + packages/node/.eslintrc.cjs | 24 +-- packages/node/package.json | 3 +- packages/node/src/api.ts | 172 ++++++++++++++++++ packages/node/test/api.test.ts | 82 +++++++++ packages/node/test/test-graphs.rivet-project | 35 ++++ packages/node/test/testUtils.ts | 19 ++ packages/node/tsconfig.eslint.json | 4 + 14 files changed, 455 insertions(+), 18 deletions(-) create mode 100644 packages/core/test/model/GraphProcessor.test.ts create mode 100644 packages/core/test/test-graphs.rivet-project create mode 100644 packages/core/test/testUtils.ts create mode 100644 packages/core/tsconfig.eslint.json create mode 100644 packages/node/test/api.test.ts create mode 100644 packages/node/test/test-graphs.rivet-project create mode 100644 packages/node/test/testUtils.ts create mode 100644 packages/node/tsconfig.eslint.json diff --git a/packages/core/.eslintrc.cjs b/packages/core/.eslintrc.cjs index a9f9fcbca..f247b3019 100644 --- a/packages/core/.eslintrc.cjs +++ b/packages/core/.eslintrc.cjs @@ -5,7 +5,7 @@ module.exports = { { files: ['*.ts', '*.tsx'], parserOptions: { - project: true, + project: './packages/core/tsconfig.eslint.json', ecmaVersion: 'latest', sourceType: 'module', }, diff --git a/packages/core/package.json b/packages/core/package.json index 790fff964..f2aaac8f5 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -40,7 +40,7 @@ "prepack": "yarn build && cp -r ../../LICENSE ../../README.md .", "publish": "yarn npm publish --access public", "lint": "eslint --ext .js,.jsx,.ts,.tsx ./src", - "test": "tsx --test test/**/*.ts" + "test": "tsx --test test/**/*.test.ts" }, "dependencies": { "@gentrace/core": "^2.1.15", diff --git a/packages/core/src/model/GraphProcessor.ts b/packages/core/src/model/GraphProcessor.ts index 496407ed2..fd71e81ab 100644 --- a/packages/core/src/model/GraphProcessor.ts +++ b/packages/core/src/model/GraphProcessor.ts @@ -121,8 +121,8 @@ export type ProcessEvents = { [key: `globalSet:${string}`]: ScalarOrArrayDataValue | undefined; }; -export type ProcessEventTuple = { - [key in keyof ProcessEvents]: [key, ProcessEvents[key]]; +export type ProcessEvent = { + [P in keyof ProcessEvents]: { type: P } & ProcessEvents[P]; }[keyof ProcessEvents]; export type GraphOutputs = Record; @@ -447,9 +447,9 @@ export class GraphProcessor { await this.#emitter.once('resume'); } - async *events(): AsyncGenerator { + async *events(): AsyncGenerator { for await (const [event, data] of this.#emitter.anyEvent()) { - yield [event, data as any]; + yield { type: event, ...(data as any) }; if (event === 'finish') { break; diff --git a/packages/core/test/model/GraphProcessor.test.ts b/packages/core/test/model/GraphProcessor.test.ts new file mode 100644 index 000000000..865dc7bb4 --- /dev/null +++ b/packages/core/test/model/GraphProcessor.test.ts @@ -0,0 +1,42 @@ +import { it, describe } from 'node:test'; +import { strict as assert } from 'node:assert'; +import { loadTestGraphInProcessor, testProcessContext } from '../testUtils'; + +describe('GraphProcessor', () => { + it('Can run passthrough graph', async () => { + const processor = await loadTestGraphInProcessor('Passthrough'); + + const outputs = await processor.processGraph(testProcessContext(), { + input: { + type: 'string', + value: 'input value', + }, + }); + + assert.deepEqual(outputs.output, { + type: 'string', + value: 'input value', + }); + }); + + it('Can stream graph processor events', async () => { + const processor = await loadTestGraphInProcessor('Passthrough'); + + processor.processGraph(testProcessContext(), { + input: { + type: 'string', + value: 'input value', + }, + }); + + const eventNames: string[] = []; + for await (const event of processor.events()) { + if (event.type !== 'trace') { + eventNames.push(event.type); + } + } + + assert.equal(eventNames[eventNames.length - 2], 'done'); + assert.equal(eventNames[eventNames.length - 1], 'finish'); + }); +}); diff --git a/packages/core/test/test-graphs.rivet-project b/packages/core/test/test-graphs.rivet-project new file mode 100644 index 000000000..a63e8d6ec --- /dev/null +++ b/packages/core/test/test-graphs.rivet-project @@ -0,0 +1,35 @@ +version: 4 +data: + attachedData: + trivet: + testSuites: [] + version: 1 + graphs: + kqaNrBo0WpJ1EOc2hj0zK: + metadata: + description: "" + id: kqaNrBo0WpJ1EOc2hj0zK + name: Passthrough + nodes: + '[Dp5_0MQuZk7_UTdBQGX-P]:graphOutput "Graph Output"': + data: + dataType: string + id: output + visualData: 1185/525/330/10// + '[hHAeA3eIMmdfGFOYeool0]:passthrough "Passthrough"': + outgoingConnections: + - output1->"Graph Output" Dp5_0MQuZk7_UTdBQGX-P/value + visualData: 928/554/205/9// + '[pbV0xrVlBYe1I3Mew3CIT]:graphInput "Graph Input"': + data: + dataType: string + id: input + useDefaultValueInput: false + outgoingConnections: + - data->"Passthrough" hHAeA3eIMmdfGFOYeool0/input1 + visualData: 549/516/330/11// + metadata: + description: "" + id: ytCHmBvDFSkCnQ9L7DJLB + title: Untitled Project + plugins: [] diff --git a/packages/core/test/testUtils.ts b/packages/core/test/testUtils.ts new file mode 100644 index 000000000..8b4d7e6f6 --- /dev/null +++ b/packages/core/test/testUtils.ts @@ -0,0 +1,41 @@ +import { readFile } from 'node:fs/promises'; +import { deserializeProject, GraphProcessor, type ProcessContext, type Project } from '../src/index.js'; +import { dirname, join } from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const testDir = dirname(fileURLToPath(import.meta.url)); + +export async function loadTestGraphs(): Promise { + return loadProjectFromFile(join(testDir, './test-graphs.rivet-project')); +} + +export async function loadTestGraphInProcessor(graphName: string) { + const project = await loadTestGraphs(); + const graph = Object.values(project.graphs).find((g) => g.metadata!.name === graphName); + + if (!graph) { + throw new Error(`Could not find graph with name ${graphName}`); + } + + return new GraphProcessor(project, graph.metadata!.id!); +} + +export async function loadProjectFromFile(path: string): Promise { + const content = await readFile(path, { encoding: 'utf8' }); + return loadProjectFromString(content); +} + +export function loadProjectFromString(content: string): Project { + const [project] = deserializeProject(content); + return project; +} + +export function testProcessContext(): ProcessContext { + return { + settings: { + openAiKey: process.env.OPENAI_API_KEY, + openAiOrganization: process.env.OPENAI_ORG_ID, + openAiEndpoint: process.env.OPENAI_API_ENDPOINT, + }, + }; +} diff --git a/packages/core/tsconfig.eslint.json b/packages/core/tsconfig.eslint.json new file mode 100644 index 000000000..40e0cfc28 --- /dev/null +++ b/packages/core/tsconfig.eslint.json @@ -0,0 +1,4 @@ +{ + "extends": "./tsconfig.json", + "include": ["src", "test"] +} diff --git a/packages/node/.eslintrc.cjs b/packages/node/.eslintrc.cjs index cbbfa19eb..985c9017a 100644 --- a/packages/node/.eslintrc.cjs +++ b/packages/node/.eslintrc.cjs @@ -1,12 +1,14 @@ module.exports = { - "extends": "../../.eslintrc.cjs", - "root": true, - "overrides": [{ - "files": ["*.ts", "*.tsx"], - "parserOptions": { - "project": true, - "ecmaVersion": "latest", - "sourceType": "module", - } - }] -} + extends: '../../.eslintrc.cjs', + root: true, + overrides: [ + { + files: ['*.ts', '*.tsx'], + parserOptions: { + project: './packages/node/tsconfig.eslint.json', + ecmaVersion: 'latest', + sourceType: 'module', + }, + }, + ], +}; diff --git a/packages/node/package.json b/packages/node/package.json index 4e41444fb..6f11662ca 100644 --- a/packages/node/package.json +++ b/packages/node/package.json @@ -36,7 +36,8 @@ "build:cjs": "rm -rf dist/cjs && tsx ../core/bundle.esbuild.ts", "prepack": "yarn build && cp -r ../../LICENSE ../../README.md .", "publish": "yarn npm publish --access public", - "lint": "eslint --ext .js,.jsx,.ts,.tsx ./src" + "lint": "eslint --ext .js,.jsx,.ts,.tsx ./src", + "test": "tsx --test test/**/*.test.ts" }, "dependencies": { "@ironclad/rivet-core": "workspace:^", diff --git a/packages/node/src/api.ts b/packages/node/src/api.ts index bccc727ca..0cee351f2 100644 --- a/packages/node/src/api.ts +++ b/packages/node/src/api.ts @@ -14,6 +14,12 @@ import { deserializeProject, globalRivetNodeRegistry, type AttachedData, + type NodeId, + coerceType, + type PortId, + type GraphOutputs, + type Outputs, + type Inputs, } from '@ironclad/rivet-core'; import { readFile } from 'node:fs/promises'; @@ -22,6 +28,7 @@ import type { PascalCase } from 'type-fest'; import { NodeNativeApi } from './native/NodeNativeApi.js'; import { mapValues } from 'lodash-es'; import * as events from 'node:events'; +import { match } from 'ts-pattern'; export async function loadProjectFromFile(path: string): Promise { const content = await readFile(path, { encoding: 'utf8' }); @@ -198,6 +205,8 @@ export function createProcessor(project: Project, options: RunGraphOptions) { processor, inputs: resolvedInputs, contextValues: resolvedContextValues, + getEvents: (spec: RivetEventStreamFilterSpec) => getProcessorEvents(processor, spec), + getEventStream: (spec: RivetEventStreamFilterSpec) => getProcessorEventStream(processor, spec), async run() { const outputs = await processor.processGraph( { @@ -251,3 +260,166 @@ function getPluginEnvFromProcessEnv(registry?: NodeRegistration) { } return pluginEnv; } + +export type RivetEventStreamFilterSpec = { + /** Stream partial output deltas for the specified node IDs or node titles. */ + partialOutputs?: string[] | true; + + /** Send the graph output when done? */ + done?: boolean; + + /** If the graph errors, send an error event? */ + error?: boolean; + + /** Stream node start events for the specified node IDs or node titles. */ + nodeStart?: string[] | true; + + /** Stream node finish events for the specified nodeIDs or node titles. */ + nodeFinish?: string[] | true; +}; + +/** Map of all possible event names to their data for streaming events. */ +export type RivetEventStreamEvent = { + /** Deltas for partial outputs. */ + partialOutput: { + nodeId: NodeId; + nodeTitle: string; + delta: string; + }; + + nodeStart: { + nodeId: NodeId; + nodeTitle: string; + inputs: Inputs; + }; + + nodeFinish: { + nodeId: NodeId; + nodeTitle: string; + outputs: Outputs; + }; + + done: { + graphOutput: GraphOutputs; + }; + + error: { + error: string; + }; +}; + +export type RivetEventStreamEventInfo = { + [P in keyof RivetEventStreamEvent]: { + type: P; + } & RivetEventStreamEvent[P]; +}[keyof RivetEventStreamEvent]; + +/** A simplified way to listen and stream processor events, including filtering. */ +export async function* getProcessorEvents( + processor: GraphProcessor, + spec: RivetEventStreamFilterSpec, +): AsyncGenerator { + const previousIndexes = new Map(); + for await (const event of processor.events()) { + if (event.type === 'partialOutput') { + if ( + spec.partialOutputs === true || + !spec.partialOutputs?.includes(event.node.id) || + !spec.partialOutputs?.includes(event.node.title) + ) { + return; + } + + const currentOutput = coerceType(event.outputs['response' as PortId], 'string'); + + const delta = currentOutput.slice(previousIndexes.get(event.node.id) ?? 0); + + yield { + type: 'partialOutput', + nodeId: event.node.id, + nodeTitle: event.node.title, + delta, + }; + + previousIndexes.set(event.node.id, currentOutput.length); + } else if (event.type === 'done') { + if (spec.done) { + yield { + type: 'done', + graphOutput: event.results, + }; + } + } else if (event.type === 'error') { + if (spec.error) { + yield { + type: 'error', + error: typeof event.error === 'string' ? event.error : event.error.toString(), + }; + } + } else if (event.type === 'nodeStart') { + if ( + spec.nodeStart === true || + spec.nodeStart?.includes(event.node.id) || + spec.nodeStart?.includes(event.node.title) + ) { + yield { + type: 'nodeStart', + inputs: event.inputs, + nodeId: event.node.id, + nodeTitle: event.node.title, + }; + } + } else if (event.type === 'nodeFinish') { + if ( + spec.nodeFinish === true || + spec.nodeFinish?.includes(event.node.id) || + spec.nodeFinish?.includes(event.node.title) + ) { + yield { + type: 'nodeFinish', + outputs: event.outputs, + nodeId: event.node.id, + nodeTitle: event.node.title, + }; + } + } + } +} + +/** + * Creates a ReadableStream for processor events, following the Server-Sent Events protocol. + * https://developer.mozilla.org/en-US/docs/Web/API/EventSource + * + * Includes configuration for what events to send to the client, for example you can stream the partial output deltas + * for specific nodes, and/or the graph output when done. + */ +export function getProcessorEventStream( + processor: GraphProcessor, + + /** The spec for what you're streaming to the client */ + spec: RivetEventStreamFilterSpec, +) { + const encoder = new TextEncoder(); + + function sendEvent( + controller: ReadableStreamDefaultController, + type: T, + data: RivetEventStreamEvent[T], + ) { + const event = `event: ${type}\ndata: ${JSON.stringify(data)}\n\n`; + controller.enqueue(encoder.encode(event)); + } + + return new ReadableStream({ + async start(controller) { + try { + for await (const event of getProcessorEvents(processor, spec)) { + sendEvent(controller, event.type, event); + } + controller.close(); + } catch (err) { + controller.error(err); + } + }, + }); +} diff --git a/packages/node/test/api.test.ts b/packages/node/test/api.test.ts new file mode 100644 index 000000000..ff89be30e --- /dev/null +++ b/packages/node/test/api.test.ts @@ -0,0 +1,82 @@ +import { describe, it } from 'node:test'; +import * as assert from 'node:assert/strict'; +import { loadTestGraphs } from './testUtils'; +import { createProcessor } from '../src/index.js'; + +describe('api', () => { + it('can stream processor events', async () => { + const processor = createProcessor(await loadTestGraphs(), { + graph: 'Passthrough', + inputs: { + input: 'input value', + }, + }); + + processor.run(); + + const eventNames: string[] = []; + for await (const event of processor.getEvents({ done: true, nodeStart: true, nodeFinish: true })) { + eventNames.push(event.type); + } + + // 3 nodes start and finish + done + assert.deepEqual(eventNames, [ + 'nodeStart', + 'nodeFinish', + 'nodeStart', + 'nodeFinish', + 'nodeStart', + 'nodeFinish', + 'done', + ]); + }); + + it('can easily filter for a node', async () => { + const processor = createProcessor(await loadTestGraphs(), { + graph: 'Passthrough', + inputs: { + input: 'input value', + }, + }); + + processor.run(); + + for await (const event of processor.getEvents({ nodeStart: ['Passthrough'] })) { + assert.equal(event.type, 'nodeStart'); + } + }); + + it('Can get an event stream for a processor', async () => { + const processor = createProcessor(await loadTestGraphs(), { + graph: 'Passthrough', + inputs: { + input: 'input value', + }, + }); + + processor.run(); + + const reader = processor + .getEventStream({ + nodeFinish: true, + }) + .getReader(); + + const decoder = new TextDecoder(); + + // Kind of a mess but whatev + const eventNames: string[] = []; + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + const data = decoder.decode(value); + + const event = /event: (?.*)/.exec(data)!.groups!.event!; + eventNames.push(event); + } + + assert.deepEqual(eventNames, ['nodeFinish', 'nodeFinish', 'nodeFinish']); + }); +}); diff --git a/packages/node/test/test-graphs.rivet-project b/packages/node/test/test-graphs.rivet-project new file mode 100644 index 000000000..a63e8d6ec --- /dev/null +++ b/packages/node/test/test-graphs.rivet-project @@ -0,0 +1,35 @@ +version: 4 +data: + attachedData: + trivet: + testSuites: [] + version: 1 + graphs: + kqaNrBo0WpJ1EOc2hj0zK: + metadata: + description: "" + id: kqaNrBo0WpJ1EOc2hj0zK + name: Passthrough + nodes: + '[Dp5_0MQuZk7_UTdBQGX-P]:graphOutput "Graph Output"': + data: + dataType: string + id: output + visualData: 1185/525/330/10// + '[hHAeA3eIMmdfGFOYeool0]:passthrough "Passthrough"': + outgoingConnections: + - output1->"Graph Output" Dp5_0MQuZk7_UTdBQGX-P/value + visualData: 928/554/205/9// + '[pbV0xrVlBYe1I3Mew3CIT]:graphInput "Graph Input"': + data: + dataType: string + id: input + useDefaultValueInput: false + outgoingConnections: + - data->"Passthrough" hHAeA3eIMmdfGFOYeool0/input1 + visualData: 549/516/330/11// + metadata: + description: "" + id: ytCHmBvDFSkCnQ9L7DJLB + title: Untitled Project + plugins: [] diff --git a/packages/node/test/testUtils.ts b/packages/node/test/testUtils.ts new file mode 100644 index 000000000..c1334a731 --- /dev/null +++ b/packages/node/test/testUtils.ts @@ -0,0 +1,19 @@ +import { loadProjectFromFile, type ProcessContext, type Project } from '../src/index.js'; +import { dirname, join } from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const testDir = dirname(fileURLToPath(import.meta.url)); + +export async function loadTestGraphs(): Promise { + return loadProjectFromFile(join(testDir, './test-graphs.rivet-project')); +} + +export function testProcessContext(): ProcessContext { + return { + settings: { + openAiKey: process.env.OPENAI_API_KEY, + openAiOrganization: process.env.OPENAI_ORG_ID, + openAiEndpoint: process.env.OPENAI_API_ENDPOINT, + }, + }; +} diff --git a/packages/node/tsconfig.eslint.json b/packages/node/tsconfig.eslint.json new file mode 100644 index 000000000..40e0cfc28 --- /dev/null +++ b/packages/node/tsconfig.eslint.json @@ -0,0 +1,4 @@ +{ + "extends": "./tsconfig.json", + "include": ["src", "test"] +}