Skip to content

Commit

Permalink
Node executor has access to datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Sep 19, 2023
1 parent 0079941 commit 1e91f99
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 13 deletions.
7 changes: 6 additions & 1 deletion packages/app-executor/bin/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import {
NodeRegistration,
plugins as rivetPlugins,
registerBuiltInNodes,
NodeDatasetProvider,
} from '@ironclad/rivet-node';
import * as Rivet from '@ironclad/rivet-core';
import { RivetPluginInitializer } from '@ironclad/rivet-core';
import { InMemoryDatasetProvider, RivetPluginInitializer } from '@ironclad/rivet-core';
import yargs from 'yargs';
import { hideBin } from 'yargs/helpers';
import { P, match } from 'ts-pattern';
Expand Down Expand Up @@ -106,12 +107,16 @@ const rivetDebugger = startDebuggerServer({
}

try {
console.dir({ currentDebuggerState });

const datasetProvider = new InMemoryDatasetProvider(currentDebuggerState.datsets ?? []);
const processor = createProcessor(project, {
graph: graphId,
inputs,
...currentDebuggerState.settings!,
remoteDebugger: rivetDebugger,
registry,
datasetProvider,
onTrace: (trace) => {
console.log(trace);
},
Expand Down
7 changes: 3 additions & 4 deletions packages/app/src/components/DataStudio.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import { FC, useEffect, useState } from 'react';
import { FC, useState } from 'react';
import { useRecoilState, useRecoilValue } from 'recoil';
import { overlayOpenState } from '../state/ui';
import { css } from '@emotion/react';
import { projectState } from '../state/savedGraphs';
import { ErrorBoundary } from 'react-error-boundary';
import useIndexedDb from '../hooks/useIndexedDb';
import { selectedDatasetState } from '../state/dataStudio';
import { toast } from 'react-toastify';
import { Dataset, DatasetId, DatasetMetadata, DatasetRow, getError, newId } from '@ironclad/rivet-core';
import { DatasetId, DatasetMetadata, DatasetRow, getError, newId } from '@ironclad/rivet-core';
import Button from '@atlaskit/button';
import TextField from '@atlaskit/textfield';
import clsx from 'clsx';
Expand All @@ -27,7 +26,7 @@ export const DataStudioRenderer: FC = () => {
if (openOverlay !== 'dataStudio') return null;

return (
<ErrorBoundary fallback={null}>
<ErrorBoundary fallbackRender={() => 'Failed to render Data Studio'}>
<DataStudio onClose={() => setOpenOverlay(undefined)} />
</ErrorBoundary>
);
Expand Down
5 changes: 3 additions & 2 deletions packages/app/src/components/PromptDesigner.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import { toast } from 'react-toastify';
import { produce } from 'immer';
import { overlayOpenState } from '../state/ui';
import { BrowserDatasetProvider } from '../io/BrowserDatasetProvider';
import { datasetProvider } from '../utils/globals';

const styles = css`
position: fixed;
Expand Down Expand Up @@ -987,7 +988,7 @@ async function runAdHocChat(messages: ChatMessage[], config: AdHocChatConfig) {
createSubProcessor: undefined!,
settings,
nativeApi: new TauriNativeApi(),
datasetProvider: new BrowserDatasetProvider(),
datasetProvider,
processId: nanoid() as ProcessId,
executionCache: new Map(),
externalFunctions: {},
Expand Down Expand Up @@ -1045,7 +1046,7 @@ function useRunTestGroup() {
const outputs = await processor.processGraph(
{
nativeApi: new TauriNativeApi(),
datasetProvider: new BrowserDatasetProvider(),
datasetProvider,
settings,
},
{
Expand Down
5 changes: 3 additions & 2 deletions packages/app/src/hooks/useLocalExecutor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { fillMissingSettingsFromEnvironmentVariables } from '../utils/tauri';
import { trivetState } from '../state/trivet';
import { runTrivet } from '@ironclad/trivet';
import { BrowserDatasetProvider } from '../io/BrowserDatasetProvider';
import { datasetProvider } from '../utils/globals';

export function useLocalExecutor() {
const project = useRecoilValue(projectState);
Expand Down Expand Up @@ -129,7 +130,7 @@ export function useLocalExecutor() {
globalRivetNodeRegistry.getPlugins(),
),
nativeApi: new TauriNativeApi(),
datasetProvider: new BrowserDatasetProvider(),
datasetProvider,
});
}

Expand Down Expand Up @@ -187,7 +188,7 @@ export function useLocalExecutor() {
globalRivetNodeRegistry.getPlugins(),
),
nativeApi: new TauriNativeApi(),
datasetProvider: new BrowserDatasetProvider(),
datasetProvider,
},
inputs,
);
Expand Down
4 changes: 4 additions & 0 deletions packages/app/src/hooks/useRemoteExecutor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
ProcessEvents,
StringArrayDataValue,
globalRivetNodeRegistry,
serializeDatasets,
} from '@ironclad/rivet-core';
import { useCurrentExecution } from './useCurrentExecution';
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
Expand All @@ -21,6 +22,7 @@ import { userInputModalQuestionsState, userInputModalSubmitState } from '../stat
import { pluginsState } from '../state/plugins';
import { entries } from '../../../core/src/utils/typeSafety';
import { selectedExecutorState } from '../state/execution';
import { datasetProvider } from '../utils/globals';

// TODO: This allows us to retrieve the GraphOutputs from the remote debugger.
// If the remote debugger events had a unique ID for each run, this would feel a lot less hacky.
Expand Down Expand Up @@ -146,6 +148,7 @@ export function useRemoteExecutor() {
savedSettings,
globalRivetNodeRegistry.getPlugins(),
),
datasets: serializeDatasets(await datasetProvider.exportDatasetsForProject(project.metadata.id)),
});

for (const [id, dataValue] of entries(projectData)) {
Expand Down Expand Up @@ -208,6 +211,7 @@ export function useRemoteExecutor() {
savedSettings,
globalRivetNodeRegistry.getPlugins(),
),
datasets: await datasetProvider.exportDatasetsForProject(project.metadata.id),
});
}

Expand Down
10 changes: 9 additions & 1 deletion packages/app/src/io/BrowserDatasetProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export class BrowserDatasetProvider implements DatasetProvider {
};
cursorRequest.onsuccess = () => {
const cursor = cursorRequest.result;
if (cursor) {
if (cursor?.value) {
const dataset = cursor.value as DatasetMetadata;
if (dataset.projectId === projectId) {
metadata.push(dataset);
Expand Down Expand Up @@ -127,6 +127,14 @@ export class BrowserDatasetProvider implements DatasetProvider {

if (matchingDataset) {
matchingDataset.meta = metadata;
} else {
this.#currentProjectDatasets.push({
meta: metadata,
data: {
id: metadata.id,
rows: [],
},
});
}

// Sync the database
Expand Down
8 changes: 7 additions & 1 deletion packages/core/src/model/nodes/ReadDirectoryNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
}

async process(inputData: Inputs, context: InternalProcessContext): Promise<Outputs> {
const { nativeApi } = context;

if (nativeApi == null) {
throw new Error('This node requires a native API to run.');
}

const path = this.chartNode.data.usePathInput
? expectType(inputData['path' as PortId], 'string')
: this.chartNode.data.path;
Expand Down Expand Up @@ -181,7 +187,7 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
}

try {
const files = await context.nativeApi.readdir(path, undefined, {
const files = await nativeApi.readdir(path, undefined, {
recursive,
includeDirectories,
filterGlobs,
Expand Down
8 changes: 7 additions & 1 deletion packages/core/src/model/nodes/ReadFileNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,18 @@ export class ReadFileNodeImpl extends NodeImpl<ReadFileNode> {
inputData: Record<PortId, DataValue>,
context: InternalProcessContext,
): Promise<Record<PortId, DataValue>> {
const { nativeApi } = context;

if (nativeApi == null) {
throw new Error('This node requires a native API to run.');
}

const path = this.chartNode.data.usePathInput
? expectType(inputData['path' as PortId], 'string')
: this.chartNode.data.path;

try {
const content = await context.nativeApi.readTextFile(path, undefined);
const content = await nativeApi.readTextFile(path, undefined);
return {
['content' as PortId]: { type: 'string', value: content },
};
Expand Down
3 changes: 3 additions & 0 deletions packages/node/src/api.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
DataValue,
DatasetProvider,
ExternalFunction,
GraphId,
GraphProcessor,
Expand Down Expand Up @@ -47,6 +48,7 @@ export type RunGraphOptions = {
context?: Record<string, LooseDataValue>;
remoteDebugger?: RivetDebuggerServer;
nativeApi?: NativeApi;
datasetProvider?: DatasetProvider;
externalFunctions?: {
[key: string]: ExternalFunction;
};
Expand Down Expand Up @@ -193,6 +195,7 @@ export function createProcessor(project: Project, options: RunGraphOptions) {
const outputs = await processor.processGraph(
{
nativeApi: options.nativeApi ?? new NodeNativeApi(),
datasetProvider: options.datasetProvider,
settings: {
openAiKey: options.openAiKey ?? '',
openAiOrganization: options.openAiOrganization ?? '',
Expand Down
10 changes: 9 additions & 1 deletion packages/node/src/debugger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import {
NodeId,
StringArrayDataValue,
DataId,
deserializeDatasets,
CombinedDataset,
} from '@ironclad/rivet-core';
import { match } from 'ts-pattern';
import Emittery from 'emittery';
Expand All @@ -32,6 +34,7 @@ export interface DebuggerEvents {
export const currentDebuggerState = {
uploadedProject: undefined as Project | undefined,
settings: undefined as Settings | undefined,
datsets: [] as CombinedDataset[] | undefined,
};

export type DynamicGraphRunOptions = {
Expand Down Expand Up @@ -90,9 +93,14 @@ export function startDebuggerServer(
})
.with({ type: 'set-dynamic-data' }, async () => {
if (options.allowGraphUpload) {
const { project, settings } = message.data as { project: Project; settings: Settings };
const { project, settings, datasets } = message.data as {
project: Project;
settings: Settings;
datasets: string;
};
currentDebuggerState.uploadedProject = project;
currentDebuggerState.settings = settings;
currentDebuggerState.datsets = deserializeDatasets(datasets);
}
})
.otherwise(async () => {
Expand Down
1 change: 1 addition & 0 deletions packages/node/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export * from '@ironclad/rivet-core';
export * from './native/NodeNativeApi.js';
export * from './api.js';
export * from './debugger.js';
export * from './native/NodeDatasetProvider.js';

0 comments on commit 1e91f99

Please sign in to comment.