From 0079941b59cbbb8e20461a79e22ee23737225fee Mon Sep 17 00:00:00 2001 From: Andy Brenneke Date: Tue, 19 Sep 2023 15:02:14 -0700 Subject: [PATCH] Persist datasets to data file --- packages/app/src-tauri/src/main.rs | 27 ++- packages/app/src/hooks/useDatasets.ts | 11 +- packages/app/src/hooks/useLoadProject.ts | 47 ++-- packages/app/src/hooks/useSaveProject.ts | 25 +- packages/app/src/io/BrowserDatasetProvider.ts | 217 ++++++++++++------ packages/app/src/io/TauriIOProvider.ts | 60 ++++- packages/app/src/utils/tauri.ts | 4 + .../core/src/integrations/DatasetProvider.ts | 108 ++++++++- .../src/utils/serialization/serialization.ts | 17 +- .../utils/serialization/serialization_v4.ts | 25 ++ .../node/src/native/NodeDatasetProvider.ts | 24 ++ 11 files changed, 456 insertions(+), 109 deletions(-) create mode 100644 packages/node/src/native/NodeDatasetProvider.ts diff --git a/packages/app/src-tauri/src/main.rs b/packages/app/src-tauri/src/main.rs index ace8a9237..5aabc06bf 100644 --- a/packages/app/src-tauri/src/main.rs +++ b/packages/app/src-tauri/src/main.rs @@ -3,7 +3,9 @@ windows_subsystem = "windows" )] -use tauri::{CustomMenuItem, Menu, MenuItem, Submenu}; +use std::path::Path; + +use tauri::{AppHandle, CustomMenuItem, InvokeError, Manager, Menu, MenuItem, Submenu}; mod plugins; fn main() { @@ -12,7 +14,8 @@ fn main() { .plugin(tauri_plugin_window_state::Builder::default().build()) .invoke_handler(tauri::generate_handler![ get_environment_variable, - plugins::extract_package_plugin_tarball + plugins::extract_package_plugin_tarball, + allow_data_file_scope ]) .menu(create_menu()) .on_menu_event(|event| match event.menu_item_id() { @@ -34,6 +37,26 @@ fn get_environment_variable(name: &str) -> String { std::env::var(name).unwrap_or_default() } +#[tauri::command] +fn allow_data_file_scope( + app_handle: AppHandle, + project_file_path: &str, +) -> Result<(), InvokeError> { + let scope = app_handle.fs_scope(); + + let folder_path = 
Path::new(project_file_path).parent().ok_or_else(|| InvokeError::from("project file path has no parent directory"))?; // surface bad paths to the caller as an error instead of panicking + let file_name_no_extension = Path::new(project_file_path) + .file_stem() + .ok_or_else(|| InvokeError::from("project file path has no file name"))? + .to_str() + .ok_or_else(|| InvokeError::from("project file name is not valid UTF-8"))?; + let data_file_path = folder_path.join(format!("{}.rivet-data", file_name_no_extension)); + + scope.allow_file(&data_file_path)?; + + Ok(()) +} + fn create_menu() -> Menu { let about_menu = Submenu::new( "App", diff --git a/packages/app/src/hooks/useDatasets.ts b/packages/app/src/hooks/useDatasets.ts index f4b1b8911..048ab6f8f 100644 --- a/packages/app/src/hooks/useDatasets.ts +++ b/packages/app/src/hooks/useDatasets.ts @@ -8,6 +8,15 @@ import { useRecoilState } from 'recoil'; export function useDatasets(projectId: ProjectId) { const [datasets, updateDatasets] = useRecoilState(datasetsState); + const initDatasets = async () => { + try { + await datasetProvider.loadDatasets(projectId); + await reloadDatasets(); + } catch (err) { + toast.error(getError(err).message); + } + }; + const reloadDatasets = async () => { try { const datasets = await datasetProvider.getDatasetsForProject(projectId); @@ -18,7 +27,7 @@ export function useDatasets(projectId: ProjectId) { }; useEffect(() => { - reloadDatasets(); + initDatasets(); }, [projectId]); const putDataset = async (dataset: DatasetMetadata) => { diff --git a/packages/app/src/hooks/useLoadProject.ts b/packages/app/src/hooks/useLoadProject.ts index b9b9b3a46..22e3274e0 100644 --- a/packages/app/src/hooks/useLoadProject.ts +++ b/packages/app/src/hooks/useLoadProject.ts @@ -1,12 +1,11 @@ import { useSetRecoilState } from 'recoil'; import { loadedProjectState, projectDataState, projectState } from '../state/savedGraphs.js'; -import { emptyNodeGraph } from '@ironclad/rivet-core'; +import { emptyNodeGraph, getError } from '@ironclad/rivet-core'; import { graphState } from '../state/graph.js'; import { ioProvider } from '../utils/globals.js'; import { trivetState } from '../state/trivet.js'; -import { useStaticDataDatabase } from './useStaticDataDatabase'; -import { 
entries } from '../../../core/src/utils/typeSafety'; import { useSetStaticData } from './useSetStaticData'; +import { toast } from 'react-toastify'; export function useLoadProject() { const setProject = useSetRecoilState(projectState); @@ -15,30 +14,34 @@ export function useLoadProject() { const setTrivetState = useSetRecoilState(trivetState); const setStaticData = useSetStaticData(); - return () => { - ioProvider.loadProjectData(({ project, testData, path }) => { - const { data, ...projectData } = project; + return async () => { + try { + await ioProvider.loadProjectData(({ project, testData, path }) => { + const { data, ...projectData } = project; - setProject(projectData); + setProject(projectData); - if (data) { - setStaticData(data); - } + if (data) { + setStaticData(data); + } - setGraphData(emptyNodeGraph()); + setGraphData(emptyNodeGraph()); - setLoadedProjectState({ - path, - loaded: true, - }); + setLoadedProjectState({ + path, + loaded: true, + }); - setTrivetState({ - testSuites: testData.testSuites, - selectedTestSuiteId: undefined, - editingTestCaseId: undefined, - recentTestResults: undefined, - runningTests: false, + setTrivetState({ + testSuites: testData.testSuites, + selectedTestSuiteId: undefined, + editingTestCaseId: undefined, + recentTestResults: undefined, + runningTests: false, + }); }); - }); + } catch (err) { + toast.error(`Failed to load project: ${getError(err).message}`); + } }; } diff --git a/packages/app/src/hooks/useSaveProject.ts b/packages/app/src/hooks/useSaveProject.ts index 65ad7b4d3..b207542fa 100644 --- a/packages/app/src/hooks/useSaveProject.ts +++ b/packages/app/src/hooks/useSaveProject.ts @@ -2,7 +2,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'; import { loadedProjectState, projectState } from '../state/savedGraphs.js'; import { useSaveCurrentGraph } from './useSaveCurrentGraph.js'; import { produce } from 'immer'; -import { toast } from 'react-toastify'; +import { toast, Id as ToastId } from 
'react-toastify'; import { ioProvider } from '../utils/globals.js'; import { trivetState } from '../state/trivet.js'; @@ -23,7 +23,19 @@ export function useSaveProject() { draft.graphs[savedGraph.metadata!.id!] = savedGraph; }); + // Large datasets can save slowly because of indexeddb, so show a "saving..." toast if it's a slow save + let saving: ToastId | undefined; + const savingTimeout = setTimeout(() => { + saving = toast.info('Saving project'); + }, 500); + await ioProvider.saveProjectDataNoPrompt(newProject, { testSuites }, loadedProject.path); + + if (saving != null) { + toast.dismiss(saving); + } + clearTimeout(savingTimeout); + toast.success('Project saved'); setLoadedProject({ loaded: true, @@ -38,8 +50,19 @@ export function useSaveProject() { draft.graphs[savedGraph.metadata!.id!] = savedGraph; }); + // Large datasets can save slowly because of indexeddb, so show a "saving..." toast if it's a slow save + let saving: ToastId | undefined; + const savingTimeout = setTimeout(() => { + saving = toast.info('Saving project'); + }, 500); + const filePath = await ioProvider.saveProjectData(newProject, { testSuites }); + if (saving != null) { + toast.dismiss(saving); + } + clearTimeout(savingTimeout); + if (filePath) { toast.success('Project saved'); setLoadedProject({ diff --git a/packages/app/src/io/BrowserDatasetProvider.ts b/packages/app/src/io/BrowserDatasetProvider.ts index c743b5f5a..f615bce9b 100644 --- a/packages/app/src/io/BrowserDatasetProvider.ts +++ b/packages/app/src/io/BrowserDatasetProvider.ts @@ -1,6 +1,18 @@ -import { DatasetRow, DatasetId, DatasetMetadata, DatasetProvider, ProjectId, Dataset } from '@ironclad/rivet-core'; +import { + DatasetRow, + DatasetId, + DatasetMetadata, + DatasetProvider, + ProjectId, + Dataset, + CombinedDataset, +} from '@ironclad/rivet-core'; +import { cloneDeep } from 'lodash-es'; export class BrowserDatasetProvider implements DatasetProvider { + currentProjectId: ProjectId | undefined; + #currentProjectDatasets: 
CombinedDataset[] = []; + async getDatasetDatabase(): Promise { return new Promise((resolve, reject) => { const openRequest = window.indexedDB.open('datasets', 2); @@ -26,116 +38,144 @@ export class BrowserDatasetProvider implements DatasetProvider { }); } - async getDatasetMetadata(id: DatasetId): Promise { - const metadataStore = await this.getDatasetDatabase(); + async loadDatasets(projectId: ProjectId): Promise { + const db = await this.getDatasetDatabase(); - const transaction = metadataStore.transaction('datasets', 'readonly'); - const store = transaction.objectStore('datasets'); - const request = store.get(id); - return new Promise((resolve, reject) => { - request.onsuccess = () => resolve(request.result); - request.onerror = reject; + const store = db.transaction('datasets', 'readonly').objectStore('datasets'); + + const metadata: DatasetMetadata[] = []; + + await new Promise((resolve, reject) => { + const cursorRequest = store.openCursor(); + cursorRequest.onerror = () => { + reject(cursorRequest.error); + }; + cursorRequest.onsuccess = () => { + const cursor = cursorRequest.result; + if (cursor) { + const dataset = cursor.value as DatasetMetadata; + if (dataset.projectId === projectId) { + metadata.push(dataset); + } + cursor.continue(); + } else { + resolve(); + } + }; }); + + const dataStore = db.transaction('data', 'readonly').objectStore('data'); + + const data = await Promise.all( + metadata.map(async (meta) => { + const dataset = await toPromise(dataStore.get(meta.id)); + return dataset; + }), + ); + + this.currentProjectId = projectId; + this.#currentProjectDatasets = metadata.map( + (meta, i): CombinedDataset => ({ + meta, + data: data[i] ?? 
{ + id: meta.id, + rows: [], + }, + }), + ); } - async getDatasetsForProject(projectId: ProjectId): Promise { - const metadataStore = await this.getDatasetDatabase(); + async getDatasetMetadata(id: DatasetId): Promise { + return this.#currentProjectDatasets.find((d) => d.meta.id === id)?.meta; + } - const transaction = metadataStore.transaction('datasets', 'readonly'); - const store = transaction.objectStore('datasets'); - const request = store.getAll(); - return new Promise((resolve, reject) => { - request.onsuccess = () => { - const datasets = request.result as DatasetMetadata[]; + async getDatasetsForProject(projectId: ProjectId): Promise { + if (this.currentProjectId !== projectId) { + throw new Error('Project not loaded. Call loadDatasets first.'); + } - return resolve(datasets.filter((d) => d.projectId === projectId)); - }; - request.onerror = reject; - }); + return this.#currentProjectDatasets.map((d) => d.meta); } async getDatasetData(id: DatasetId): Promise { - const dataStore = await this.getDatasetDatabase(); - - const transaction = dataStore.transaction('data', 'readonly'); - const store = transaction.objectStore('data'); - const request = store.get(id); - return new Promise((resolve, reject) => { - request.onsuccess = () => { - const dataset = request.result as Dataset | null; - return resolve( - dataset ?? { - id, - rows: [], - }, - ); - }; - request.onerror = reject; - }); + return ( + this.#currentProjectDatasets.find((d) => d.meta.id === id)?.data ?? 
{ + id, + rows: [], + } + ); } async putDatasetData(id: DatasetId, data: Dataset): Promise { + const dataset = this.#currentProjectDatasets.find((d) => d.meta.id === id); + if (!dataset) { + throw new Error(`Dataset ${id} not found`); + } + + dataset.data = data; + + // Sync the database const dataStore = await this.getDatasetDatabase(); const transaction = dataStore.transaction('data', 'readwrite'); const store = transaction.objectStore('data'); - const request = store.delete(id); - await new Promise((resolve, reject) => { - request.onsuccess = () => resolve(); - request.onerror = reject; - }); - - const putRequest = store.put(data, id); - return new Promise((resolve, reject) => { - putRequest.onsuccess = () => resolve(); - putRequest.onerror = reject; - }); + await toPromise(store.put(data, id)); } async putDatasetMetadata(metadata: DatasetMetadata): Promise { + const matchingDataset = this.#currentProjectDatasets.find((d) => d.meta.id === metadata.id); + if (matchingDataset) { + matchingDataset.meta = metadata; + } else if (metadata.projectId === this.currentProjectId) { // new dataset for the loaded project: cache it so getDatasetsForProject and putDatasetData can see it + this.#currentProjectDatasets.push({ meta: metadata, data: { id: metadata.id, rows: [] } }); + } + // Sync the database const metadataStore = await this.getDatasetDatabase(); const transaction = metadataStore.transaction('datasets', 'readwrite'); const store = transaction.objectStore('datasets'); - const request = store.put(metadata, metadata.id); - await new Promise((resolve, reject) => { - request.onsuccess = () => resolve(); - request.onerror = reject; - }); + await toPromise(store.put(metadata, metadata.id)); } async clearDatasetData(id: DatasetId): Promise { + const dataset = this.#currentProjectDatasets.find((d) => d.meta.id === id); + if (!dataset) { + return; + } + + dataset.data = { + id, + rows: [], + }; + + // Sync the database const dataStore = await this.getDatasetDatabase(); const transaction = dataStore.transaction('data', 'readwrite'); const store = transaction.objectStore('data'); - const request = store.delete(id); - return new Promise((resolve, reject) => { - request.onsuccess = () => resolve(); - request.onerror = reject; 
- }); + await toPromise(store.delete(id)); } async deleteDataset(id: DatasetId): Promise { + const index = this.#currentProjectDatasets.findIndex((d) => d.meta.id === id); + if (index === -1) { + return; + } + + this.#currentProjectDatasets.splice(index, 1); + + // Sync the database const metadataStore = await this.getDatasetDatabase(); - const transaction = metadataStore.transaction('datasets', 'readwrite'); - const store = transaction.objectStore('datasets'); - const request = store.delete(id); - await new Promise((resolve, reject) => { - request.onsuccess = () => resolve(); - request.onerror = reject; - }); + const metaTxn = metadataStore.transaction('datasets', 'readwrite'); + const store = metaTxn.objectStore('datasets'); + await toPromise(store.delete(id)); const dataStore = await this.getDatasetDatabase(); - const dataTransaction = dataStore.transaction('data', 'readwrite'); - const dataStoreStore = dataTransaction.objectStore('data'); - const dataRequest = dataStoreStore.delete(id); - await new Promise((resolve, reject) => { - dataRequest.onsuccess = () => resolve(); - dataRequest.onerror = reject; - }); + const dataTxn = dataStore.transaction('data', 'readwrite'); + const dataStoreStore = dataTxn.objectStore('data'); + await toPromise(dataStoreStore.delete(id)); } async knnDatasetRows( @@ -155,9 +195,40 @@ export class BrowserDatasetProvider implements DatasetProvider { return sorted.slice(0, k).map((r) => ({ ...r.row, distance: r.similarity })); } + + async exportDatasetsForProject(_projectId: ProjectId): Promise { + return cloneDeep(this.#currentProjectDatasets); + } + + async importDatasetsForProject(projectId: ProjectId, datasets: CombinedDataset[]) { + this.#currentProjectDatasets = datasets; + this.currentProjectId = projectId; + + const db = await this.getDatasetDatabase(); + const transaction = db.transaction(['datasets', 'data'], 'readwrite'); + + const metadataStore = transaction.objectStore('datasets'); + const dataStore = 
transaction.objectStore('data'); + + await Promise.all( + datasets.map(async (dataset) => { + await Promise.all([ + toPromise(metadataStore.put(dataset.meta, dataset.meta.id)), + toPromise(dataStore.put(dataset.data, dataset.data.id)), + ]); + }), + ); + } } /** OpenAI embeddings are already normalized, so this is equivalent to cosine similarity */ const dotProductSimilarity = (a: number[], b: number[]): number => { return a.reduce((acc, val, i) => acc + val * b[i]!, 0); }; + +function toPromise(request: IDBRequest): Promise { + return new Promise((resolve, reject) => { + request.onsuccess = () => resolve(request.result); + request.onerror = () => reject(request.error); + }); +} diff --git a/packages/app/src/io/TauriIOProvider.ts b/packages/app/src/io/TauriIOProvider.ts index fc9a835da..eacd150ce 100644 --- a/packages/app/src/io/TauriIOProvider.ts +++ b/packages/app/src/io/TauriIOProvider.ts @@ -1,17 +1,23 @@ import { save, open } from '@tauri-apps/api/dialog'; -import { writeFile, readTextFile, readBinaryFile } from '@tauri-apps/api/fs'; +import { writeFile, readTextFile, readBinaryFile, exists } from '@tauri-apps/api/fs'; import { + CombinedDataset, + DatasetProvider, ExecutionRecorder, NodeGraph, Project, + ProjectId, + deserializeDatasets, deserializeGraph, deserializeProject, + serializeDatasets, serializeGraph, serializeProject, } from '@ironclad/rivet-core'; import { IOProvider } from './IOProvider.js'; -import { isInTauri } from '../utils/tauri.js'; +import { allowDataFileNeighbor, isInTauri } from '../utils/tauri.js'; import { SerializedTrivetData, TrivetData, deserializeTrivetData, serializeTrivetData } from '@ironclad/trivet'; +import { datasetProvider } from '../utils/globals'; export class TauriIOProvider implements IOProvider { static isSupported(): boolean { @@ -62,6 +68,8 @@ export class TauriIOProvider implements IOProvider { path: filePath, }); + await saveDatasetsFile(filePath, project); + return filePath; } @@ -77,6 +85,8 @@ export class 
TauriIOProvider implements IOProvider { contents: data, path, }); + + await saveDatasetsFile(path, project); } async loadGraphData(callback: (graphData: NodeGraph) => void) { @@ -101,7 +111,7 @@ export class TauriIOProvider implements IOProvider { } async loadProjectData(callback: (data: { project: Project; testData: TrivetData; path: string }) => void) { - const path = await open({ + const path = (await open({ filters: [ { name: 'Rivet Project', @@ -112,17 +122,19 @@ export class TauriIOProvider implements IOProvider { directory: false, recursive: false, title: 'Open graph', - }); + })) as string | undefined; if (path) { - const data = await readTextFile(path as string); + const data = await readTextFile(path); const [projectData, attachedData] = deserializeProject(data); const trivetData = attachedData.trivet ? deserializeTrivetData(attachedData.trivet as SerializedTrivetData) : { testSuites: [] }; - callback({ project: projectData, testData: trivetData, path: path as string }); + await loadDatasetsFile(path, projectData); + + callback({ project: projectData, testData: trivetData, path }); } } @@ -218,3 +230,39 @@ export class TauriIOProvider implements IOProvider { return contents; } } + +async function saveDatasetsFile(projectFilePath: string, project: Project) { + await allowDataFileNeighbor(projectFilePath); + + const dataPath = projectFilePath.replace(/\.rivet-project$/, '.rivet-data'); // suffix-anchored: only the file extension is rewritten, never a '.rivet-project' earlier in the path + const datasets = await datasetProvider.exportDatasetsForProject(project.metadata.id); + + if (datasets.length > 0 || (await exists(dataPath))) { + const serializedDatasets = serializeDatasets(datasets); + + await writeFile({ + contents: serializedDatasets, + path: dataPath, + }); + } +} + +async function loadDatasetsFile(projectFilePath: string, project: Project) { + await allowDataFileNeighbor(projectFilePath); + + const datasetsFilePath = projectFilePath.replace(/\.rivet-project$/, '.rivet-data'); // same suffix-anchored derivation as saveDatasetsFile + + const datasetsFileExists = await exists(datasetsFilePath); + + // No data 
file, so just no datasets + if (!datasetsFileExists) { + await datasetProvider.importDatasetsForProject(project.metadata.id, []); + return; + } + + const fileContents = await readTextFile(datasetsFilePath); + + const datasets = deserializeDatasets(fileContents); + + await datasetProvider.importDatasetsForProject(project.metadata.id, datasets); +} diff --git a/packages/app/src/utils/tauri.ts b/packages/app/src/utils/tauri.ts index 9cdff78a1..1ed8ecba8 100644 --- a/packages/app/src/utils/tauri.ts +++ b/packages/app/src/utils/tauri.ts @@ -54,3 +54,7 @@ export async function fillMissingSettingsFromEnvironmentVariables(settings: Part return fullSettings; } + +export async function allowDataFileNeighbor(projectFilePath: string): Promise { + await invoke('allow_data_file_scope', { projectFilePath }); +} diff --git a/packages/core/src/integrations/DatasetProvider.ts b/packages/core/src/integrations/DatasetProvider.ts index 2312c9e96..8893472c4 100644 --- a/packages/core/src/integrations/DatasetProvider.ts +++ b/packages/core/src/integrations/DatasetProvider.ts @@ -1,8 +1,8 @@ -import { Opaque } from 'type-fest'; -import { Dataset, DatasetId, DatasetMetadata, DatasetRow, ProjectId } from '../index.js'; +import { CombinedDataset, Dataset, DatasetId, DatasetMetadata, DatasetRow, ProjectId } from '../index.js'; +import { cloneDeep } from 'lodash-es'; export interface DatasetProvider { - getDatasetMetadata(id: DatasetId): Promise; + getDatasetMetadata(id: DatasetId): Promise; getDatasetsForProject(projectId: ProjectId): Promise; @@ -18,4 +18,106 @@ export interface DatasetProvider { /** Gets the K nearest neighbor rows to the given vector. 
*/ knnDatasetRows(datasetId: DatasetId, k: number, vector: number[]): Promise<(DatasetRow & { distance?: number })[]>; + + exportDatasetsForProject(projectId: ProjectId): Promise; } + +export class InMemoryDatasetProvider implements DatasetProvider { + #datasets; + + constructor(datasets: CombinedDataset[]) { + this.#datasets = datasets; + } + + async getDatasetMetadata(id: DatasetId): Promise { + const dataset = this.#datasets.find((d) => d.meta.id === id); + return dataset?.meta; + } + + async getDatasetsForProject(projectId: ProjectId): Promise { + return this.#datasets.map((d) => d.meta); + } + + async getDatasetData(id: DatasetId): Promise { + const dataset = this.#datasets.find((d) => d.meta.id === id); + if (!dataset) { + return { id, rows: [] }; + } + return dataset.data; + } + + async putDatasetData(id: DatasetId, data: Dataset): Promise { + const dataset = this.#datasets.find((d) => d.meta.id === id); + if (!dataset) { + throw new Error(`Dataset ${id} not found`); + } + + dataset.data = data; + } + + async putDatasetMetadata(metadata: DatasetMetadata): Promise { + const matchingDataset = this.#datasets.find((d) => d.meta.id === metadata.id); + + if (matchingDataset) { + matchingDataset.meta = metadata; + return; + } + + this.#datasets.push({ + meta: metadata, + data: { + id: metadata.id, + rows: [], + }, + }); + } + + async clearDatasetData(id: DatasetId): Promise { + const dataset = this.#datasets.find((d) => d.meta.id === id); + if (!dataset) { + return; + } + + dataset.data = { + id, + rows: [], + }; + } + + async deleteDataset(id: DatasetId): Promise { + const index = this.#datasets.findIndex((d) => d.meta.id === id); + if (index === -1) { + return; + } + + this.#datasets.splice(index, 1); + } + + async knnDatasetRows( + datasetId: DatasetId, + k: number, + vector: number[], + ): Promise<(DatasetRow & { distance?: number | undefined })[]> { + const allRows = await this.getDatasetData(datasetId); + + const sorted = allRows.rows + .filter((row) => 
row.embedding != null) + .map((row) => ({ + row, + similarity: dotProductSimilarity(vector, row.embedding!), + })) + .sort((a, b) => b.similarity - a.similarity); + + return sorted.slice(0, k).map((r) => ({ ...r.row, distance: r.similarity })); + } + + async exportDatasetsForProject(_projectId: ProjectId): Promise { + // Cloning is safest... but slow + return cloneDeep(this.#datasets); + } +} + +/** OpenAI embeddings are already normalized, so this is equivalent to cosine similarity */ +const dotProductSimilarity = (a: number[], b: number[]): number => { + return a.reduce((acc, val, i) => acc + val * b[i]!, 0); +}; diff --git a/packages/core/src/utils/serialization/serialization.ts b/packages/core/src/utils/serialization/serialization.ts index 941cbeb8e..e4b260a32 100644 --- a/packages/core/src/utils/serialization/serialization.ts +++ b/packages/core/src/utils/serialization/serialization.ts @@ -1,10 +1,12 @@ // @ts-ignore import * as yaml from 'yaml'; import { graphV3Deserializer, projectV3Deserializer } from './serialization_v3.js'; -import { Project, NodeGraph } from '../../index.js'; +import { Project, NodeGraph, ProjectId, DatasetProvider, Dataset, DatasetMetadata } from '../../index.js'; import { getError } from '../errors.js'; import { AttachedData, yamlProblem } from './serializationUtils.js'; import { + datasetV4Deserializer, + datasetV4Serializer, graphV4Deserializer, graphV4Serializer, projectV4Deserializer, @@ -79,3 +81,16 @@ export function deserializeGraph(serializedGraph: unknown): NodeGraph { } } } + +export type CombinedDataset = { + meta: DatasetMetadata; + data: Dataset; +}; + +export function serializeDatasets(datasets: CombinedDataset[]): string { + return datasetV4Serializer(datasets); +} + +export function deserializeDatasets(serializedDatasets: string): CombinedDataset[] { + return datasetV4Deserializer(serializedDatasets); +} diff --git a/packages/core/src/utils/serialization/serialization_v4.ts 
b/packages/core/src/utils/serialization/serialization_v4.ts index 7819e8c40..6ee57c332 100644 --- a/packages/core/src/utils/serialization/serialization_v4.ts +++ b/packages/core/src/utils/serialization/serialization_v4.ts @@ -9,6 +9,9 @@ import { PortId, ProjectId, ChartNodeVariant, + DatasetProvider, + Dataset, + CombinedDataset, } from '../../index.js'; import stableStringify from 'safe-stable-stringify'; import * as yaml from 'yaml'; @@ -265,3 +268,25 @@ function fromSerializedConnection(connection: SerializedNodeConnection, nodeId: inputNodeId: inputNodeId as NodeId, }; } + +export function datasetV4Serializer(datasets: CombinedDataset[]): string { + const dataContainer = { + datasets, + }; + + const data = JSON.stringify(dataContainer); + + return data; +} + +export function datasetV4Deserializer(serializedDatasets: string): CombinedDataset[] { + const stringData = serializedDatasets as string; + + const dataContainer = JSON.parse(stringData) as { datasets: CombinedDataset[] } | null; + + if (dataContainer == null || !Array.isArray(dataContainer.datasets)) { // JSON.parse may yield null or a non-array `datasets`; reject both with the same error + throw new Error('Invalid dataset data'); + } + + return dataContainer.datasets; +} diff --git a/packages/node/src/native/NodeDatasetProvider.ts b/packages/node/src/native/NodeDatasetProvider.ts new file mode 100644 index 000000000..c98448fea --- /dev/null +++ b/packages/node/src/native/NodeDatasetProvider.ts @@ -0,0 +1,24 @@ +import { InMemoryDatasetProvider, deserializeDatasets } from '@ironclad/rivet-core'; +import { readFile } from 'node:fs/promises'; + +export class NodeDatasetProvider extends InMemoryDatasetProvider { + static async fromDatasetsFile(datasetsFilePath: string): Promise { + try { + const fileContents = await readFile(datasetsFilePath, 'utf8'); + const datasets = deserializeDatasets(fileContents); + return new NodeDatasetProvider(datasets); + } catch (err) { + // No data file, so just no datasets + if ((err as any).code === 'ENOENT') { + return new NodeDatasetProvider([]); + } + + throw err; + } + } + + static async 
fromProjectFile(projectFilePath: string): Promise { + const dataFilePath = projectFilePath.replace(/\.rivet-project$/, '.rivet-data'); + return NodeDatasetProvider.fromDatasetsFile(dataFilePath); + } +}