Skip to content

Commit

Permalink
Add simplified graph event streaming: an async generator, and an SSE …
Browse files Browse the repository at this point in the history
…implementation
  • Loading branch information
abrenneke committed Oct 20, 2023
1 parent 54e5fff commit 329f029
Show file tree
Hide file tree
Showing 14 changed files with 455 additions and 18 deletions.
2 changes: 1 addition & 1 deletion packages/core/.eslintrc.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module.exports = {
{
files: ['*.ts', '*.tsx'],
parserOptions: {
project: true,
project: './packages/core/tsconfig.eslint.json',
ecmaVersion: 'latest',
sourceType: 'module',
},
Expand Down
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"prepack": "yarn build && cp -r ../../LICENSE ../../README.md .",
"publish": "yarn npm publish --access public",
"lint": "eslint --ext .js,.jsx,.ts,.tsx ./src",
"test": "tsx --test test/**/*.ts"
"test": "tsx --test test/**/*.test.ts"
},
"dependencies": {
"@gentrace/core": "^2.1.15",
Expand Down
8 changes: 4 additions & 4 deletions packages/core/src/model/GraphProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ export type ProcessEvents = {
[key: `globalSet:${string}`]: ScalarOrArrayDataValue | undefined;
};

export type ProcessEventTuple = {
[key in keyof ProcessEvents]: [key, ProcessEvents[key]];
export type ProcessEvent = {
[P in keyof ProcessEvents]: { type: P } & ProcessEvents[P];
}[keyof ProcessEvents];

export type GraphOutputs = Record<string, DataValue>;
Expand Down Expand Up @@ -447,9 +447,9 @@ export class GraphProcessor {
await this.#emitter.once('resume');
}

async *events(): AsyncGenerator<ProcessEventTuple> {
async *events(): AsyncGenerator<ProcessEvent> {
for await (const [event, data] of this.#emitter.anyEvent()) {
yield [event, data as any];
yield { type: event, ...(data as any) };

if (event === 'finish') {
break;
Expand Down
42 changes: 42 additions & 0 deletions packages/core/test/model/GraphProcessor.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { it, describe } from 'node:test';
import { strict as assert } from 'node:assert';
import { loadTestGraphInProcessor, testProcessContext } from '../testUtils';

describe('GraphProcessor', () => {
it('Can run passthrough graph', async () => {
const processor = await loadTestGraphInProcessor('Passthrough');

const outputs = await processor.processGraph(testProcessContext(), {
input: {
type: 'string',
value: 'input value',
},
});

assert.deepEqual(outputs.output, {
type: 'string',
value: 'input value',
});
});

it('Can stream graph processor events', async () => {
const processor = await loadTestGraphInProcessor('Passthrough');

processor.processGraph(testProcessContext(), {
input: {
type: 'string',
value: 'input value',
},
});

const eventNames: string[] = [];
for await (const event of processor.events()) {
if (event.type !== 'trace') {
eventNames.push(event.type);
}
}

assert.equal(eventNames[eventNames.length - 2], 'done');
assert.equal(eventNames[eventNames.length - 1], 'finish');
});
});
35 changes: 35 additions & 0 deletions packages/core/test/test-graphs.rivet-project
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
version: 4
data:
attachedData:
trivet:
testSuites: []
version: 1
graphs:
kqaNrBo0WpJ1EOc2hj0zK:
metadata:
description: ""
id: kqaNrBo0WpJ1EOc2hj0zK
name: Passthrough
nodes:
'[Dp5_0MQuZk7_UTdBQGX-P]:graphOutput "Graph Output"':
data:
dataType: string
id: output
visualData: 1185/525/330/10//
'[hHAeA3eIMmdfGFOYeool0]:passthrough "Passthrough"':
outgoingConnections:
- output1->"Graph Output" Dp5_0MQuZk7_UTdBQGX-P/value
visualData: 928/554/205/9//
'[pbV0xrVlBYe1I3Mew3CIT]:graphInput "Graph Input"':
data:
dataType: string
id: input
useDefaultValueInput: false
outgoingConnections:
- data->"Passthrough" hHAeA3eIMmdfGFOYeool0/input1
visualData: 549/516/330/11//
metadata:
description: ""
id: ytCHmBvDFSkCnQ9L7DJLB
title: Untitled Project
plugins: []
41 changes: 41 additions & 0 deletions packages/core/test/testUtils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { readFile } from 'node:fs/promises';
import { deserializeProject, GraphProcessor, type ProcessContext, type Project } from '../src/index.js';
import { dirname, join } from 'node:path';
import { fileURLToPath } from 'node:url';

const testDir = dirname(fileURLToPath(import.meta.url));

export async function loadTestGraphs(): Promise<Project> {
return loadProjectFromFile(join(testDir, './test-graphs.rivet-project'));
}

export async function loadTestGraphInProcessor(graphName: string) {
const project = await loadTestGraphs();
const graph = Object.values(project.graphs).find((g) => g.metadata!.name === graphName);

if (!graph) {
throw new Error(`Could not find graph with name ${graphName}`);
}

return new GraphProcessor(project, graph.metadata!.id!);
}

export async function loadProjectFromFile(path: string): Promise<Project> {
const content = await readFile(path, { encoding: 'utf8' });
return loadProjectFromString(content);
}

export function loadProjectFromString(content: string): Project {
const [project] = deserializeProject(content);
return project;
}

export function testProcessContext(): ProcessContext {
return {
settings: {
openAiKey: process.env.OPENAI_API_KEY,
openAiOrganization: process.env.OPENAI_ORG_ID,
openAiEndpoint: process.env.OPENAI_API_ENDPOINT,
},
};
}
4 changes: 4 additions & 0 deletions packages/core/tsconfig.eslint.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"extends": "./tsconfig.json",
"include": ["src", "test"]
}
24 changes: 13 additions & 11 deletions packages/node/.eslintrc.cjs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module.exports = {
"extends": "../../.eslintrc.cjs",
"root": true,
"overrides": [{
"files": ["*.ts", "*.tsx"],
"parserOptions": {
"project": true,
"ecmaVersion": "latest",
"sourceType": "module",
}
}]
}
extends: '../../.eslintrc.cjs',
root: true,
overrides: [
{
files: ['*.ts', '*.tsx'],
parserOptions: {
project: './packages/node/tsconfig.eslint.json',
ecmaVersion: 'latest',
sourceType: 'module',
},
},
],
};
3 changes: 2 additions & 1 deletion packages/node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
"build:cjs": "rm -rf dist/cjs && tsx ../core/bundle.esbuild.ts",
"prepack": "yarn build && cp -r ../../LICENSE ../../README.md .",
"publish": "yarn npm publish --access public",
"lint": "eslint --ext .js,.jsx,.ts,.tsx ./src"
"lint": "eslint --ext .js,.jsx,.ts,.tsx ./src",
"test": "tsx --test test/**/*.test.ts"
},
"dependencies": {
"@ironclad/rivet-core": "workspace:^",
Expand Down
172 changes: 172 additions & 0 deletions packages/node/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import {
deserializeProject,
globalRivetNodeRegistry,
type AttachedData,
type NodeId,
coerceType,
type PortId,
type GraphOutputs,
type Outputs,
type Inputs,
} from '@ironclad/rivet-core';

import { readFile } from 'node:fs/promises';
Expand All @@ -22,6 +28,7 @@ import type { PascalCase } from 'type-fest';
import { NodeNativeApi } from './native/NodeNativeApi.js';
import { mapValues } from 'lodash-es';
import * as events from 'node:events';
import { match } from 'ts-pattern';

export async function loadProjectFromFile(path: string): Promise<Project> {
const content = await readFile(path, { encoding: 'utf8' });
Expand Down Expand Up @@ -198,6 +205,8 @@ export function createProcessor(project: Project, options: RunGraphOptions) {
processor,
inputs: resolvedInputs,
contextValues: resolvedContextValues,
getEvents: (spec: RivetEventStreamFilterSpec) => getProcessorEvents(processor, spec),
getEventStream: (spec: RivetEventStreamFilterSpec) => getProcessorEventStream(processor, spec),
async run() {
const outputs = await processor.processGraph(
{
Expand Down Expand Up @@ -251,3 +260,166 @@ function getPluginEnvFromProcessEnv(registry?: NodeRegistration) {
}
return pluginEnv;
}

export type RivetEventStreamFilterSpec = {
/** Stream partial output deltas for the specified node IDs or node titles. */
partialOutputs?: string[] | true;

/** Send the graph output when done? */
done?: boolean;

/** If the graph errors, send an error event? */
error?: boolean;

/** Stream node start events for the specified node IDs or node titles. */
nodeStart?: string[] | true;

/** Stream node finish events for the specified nodeIDs or node titles. */
nodeFinish?: string[] | true;
};

/** Map of all possible event names to their data for streaming events. */
export type RivetEventStreamEvent = {
/** Deltas for partial outputs. */
partialOutput: {
nodeId: NodeId;
nodeTitle: string;
delta: string;
};

nodeStart: {
nodeId: NodeId;
nodeTitle: string;
inputs: Inputs;
};

nodeFinish: {
nodeId: NodeId;
nodeTitle: string;
outputs: Outputs;
};

done: {
graphOutput: GraphOutputs;
};

error: {
error: string;
};
};

export type RivetEventStreamEventInfo = {
[P in keyof RivetEventStreamEvent]: {
type: P;
} & RivetEventStreamEvent[P];
}[keyof RivetEventStreamEvent];

/** A simplified way to listen and stream processor events, including filtering. */
export async function* getProcessorEvents(
processor: GraphProcessor,
spec: RivetEventStreamFilterSpec,
): AsyncGenerator<RivetEventStreamEventInfo, void> {
const previousIndexes = new Map<NodeId, number>();
for await (const event of processor.events()) {
if (event.type === 'partialOutput') {
if (
spec.partialOutputs === true ||
!spec.partialOutputs?.includes(event.node.id) ||
!spec.partialOutputs?.includes(event.node.title)
) {
return;
}

const currentOutput = coerceType(event.outputs['response' as PortId], 'string');

const delta = currentOutput.slice(previousIndexes.get(event.node.id) ?? 0);

yield {
type: 'partialOutput',
nodeId: event.node.id,
nodeTitle: event.node.title,
delta,
};

previousIndexes.set(event.node.id, currentOutput.length);
} else if (event.type === 'done') {
if (spec.done) {
yield {
type: 'done',
graphOutput: event.results,
};
}
} else if (event.type === 'error') {
if (spec.error) {
yield {
type: 'error',
error: typeof event.error === 'string' ? event.error : event.error.toString(),
};
}
} else if (event.type === 'nodeStart') {
if (
spec.nodeStart === true ||
spec.nodeStart?.includes(event.node.id) ||
spec.nodeStart?.includes(event.node.title)
) {
yield {
type: 'nodeStart',
inputs: event.inputs,
nodeId: event.node.id,
nodeTitle: event.node.title,
};
}
} else if (event.type === 'nodeFinish') {
if (
spec.nodeFinish === true ||
spec.nodeFinish?.includes(event.node.id) ||
spec.nodeFinish?.includes(event.node.title)
) {
yield {
type: 'nodeFinish',
outputs: event.outputs,
nodeId: event.node.id,
nodeTitle: event.node.title,
};
}
}
}
}

/**
* Creates a ReadableStream for processor events, following the Server-Sent Events protocol.
* https://developer.mozilla.org/en-US/docs/Web/API/EventSource
*
* Includes configuration for what events to send to the client, for example you can stream the partial output deltas
* for specific nodes, and/or the graph output when done.
*/
export function getProcessorEventStream(
processor: GraphProcessor,

/** The spec for what you're streaming to the client */
spec: RivetEventStreamFilterSpec,
) {
const encoder = new TextEncoder();

function sendEvent<T extends keyof RivetEventStreamEvent>(
controller: ReadableStreamDefaultController,
type: T,
data: RivetEventStreamEvent[T],
) {
const event = `event: ${type}\ndata: ${JSON.stringify(data)}\n\n`;
controller.enqueue(encoder.encode(event));
}

return new ReadableStream<Uint8Array>({
async start(controller) {
try {
for await (const event of getProcessorEvents(processor, spec)) {
sendEvent(controller, event.type, event);
}
controller.close();
} catch (err) {
controller.error(err);
}
},
});
}
Loading

0 comments on commit 329f029

Please sign in to comment.