diff --git a/packages/app/src/components/CodeEditor.tsx b/packages/app/src/components/CodeEditor.tsx index d7cc0d35d..d446360bd 100644 --- a/packages/app/src/components/CodeEditor.tsx +++ b/packages/app/src/components/CodeEditor.tsx @@ -12,9 +12,21 @@ export const CodeEditor: FC<{ theme?: string; autoFocus?: boolean; onKeyDown?: (e: monaco.IKeyboardEvent) => void; + onBlur?: () => void; editorRef?: MutableRefObject; scrollBeyondLastLine?: boolean; -}> = ({ text, isReadonly, onChange, language, theme, autoFocus, onKeyDown, editorRef, scrollBeyondLastLine }) => { +}> = ({ + text, + isReadonly, + onChange, + language, + theme, + autoFocus, + onKeyDown, + onBlur, + editorRef, + scrollBeyondLastLine, +}) => { const editorContainer = useRef(null); const editorInstance = useRef(); @@ -56,6 +68,10 @@ export const CodeEditor: FC<{ onChangeLatest.current?.(editor.getValue()); }); + editor.onDidBlurEditorWidget(() => { + onBlur?.(); + }); + editorInstance.current = editor; if (editorRef) { editorRef.current = editor; diff --git a/packages/app/src/components/DataStudio.tsx b/packages/app/src/components/DataStudio.tsx index c87ac39e2..27e97bfc5 100644 --- a/packages/app/src/components/DataStudio.tsx +++ b/packages/app/src/components/DataStudio.tsx @@ -19,6 +19,7 @@ import { parse as parseCsv } from 'csv-parse/browser/esm/sync'; import { ioProvider } from '../utils/globals'; import { useDataset } from '../hooks/useDataset'; import { useDatasets } from '../hooks/useDatasets'; +import { LazyCodeEditor } from './LazyComponents'; export const DataStudioRenderer: FC = () => { const [openOverlay, setOpenOverlay] = useRecoilState(overlayOpenState); @@ -326,13 +327,21 @@ const datasetDisplayStyles = css` .cell { height: 48px; + + &.editor { + .editor-container { + min-height: 200px; + min-width: 100px; + } + } } .value { padding: 4px 8px; height: 100%; display: flex; - align-items: center; + align-items: flex-start; + overflow: hidden; } } } @@ -503,19 +512,21 @@ const DatasetEditableCell: FC<{ onChange: (value: string) => void; }> = ({ value, row, column, onChange }) => { const [editing, setEditing] = useState(false); + const [editingText, setEditingText] = useState(value); return ( -
+
{editing ? ( - { - onChange((e.target as HTMLInputElement).value); + text={value} + onChange={(e) => setEditingText(e)} + onBlur={() => { + onChange(editingText); setEditing(false); }} onKeyDown={(e) => { - if (e.key === 'Enter') { + if (e.keyCode === 3 && (e.metaKey || e.ctrlKey)) { onChange((e.target as HTMLInputElement).value); setEditing(false); } diff --git a/packages/app/src/components/nodes/DatasetNearestNeighborsNode.tsx b/packages/app/src/components/nodes/DatasetNearestNeighborsNode.tsx new file mode 100644 index 000000000..be82ba3d3 --- /dev/null +++ b/packages/app/src/components/nodes/DatasetNearestNeighborsNode.tsx @@ -0,0 +1,27 @@ +import { DatasetNearestNeighborsNode } from '@ironclad/rivet-core'; +import { NodeComponentDescriptor } from '../../hooks/useNodeTypes'; +import { FC } from 'react'; +import { useDatasets } from '../../hooks/useDatasets'; +import { useRecoilValue } from 'recoil'; +import { projectState } from '../../state/savedGraphs'; + +export const DatasetNearestNeightborsNode: FC<{ + node: DatasetNearestNeighborsNode; +}> = ({ node }) => { + const project = useRecoilValue(projectState); + const { datasets } = useDatasets(project.metadata.id); + + const dataset = datasets?.find((d) => d.id === node.data.datasetId); + + return ( +
+
+ {node.data.useDatasetIdInput ? 'Dataset from input' : dataset ? dataset.name : 'Unknown or no dataset selected'} +
+
+ ); +}; + +export const datasetNearestNeighborsNodeDescriptor: NodeComponentDescriptor<'datasetNearestNeighbors'> = { + Body: DatasetNearestNeightborsNode, +}; diff --git a/packages/app/src/hooks/useNodeTypes.ts b/packages/app/src/hooks/useNodeTypes.ts index 1ab81da56..b81210b4c 100644 --- a/packages/app/src/hooks/useNodeTypes.ts +++ b/packages/app/src/hooks/useNodeTypes.ts @@ -16,6 +16,7 @@ import { appendToDatasetNodeDescriptor } from '../components/nodes/AppendToDatas import { useRecoilValue } from 'recoil'; import { pluginRefreshCounterState } from '../state/plugins'; import { loadDatasetNodeDescriptor } from '../components/nodes/LoadDatasetNode'; +import { datasetNearestNeighborsNodeDescriptor } from '../components/nodes/DatasetNearestNeighborsNode'; export type UnknownNodeComponentDescriptor = { Body?: FC<{ node: ChartNode }>; @@ -55,6 +56,7 @@ const overriddenDescriptors: Partial = { audio: audioNodeDescriptor, appendToDataset: appendToDatasetNodeDescriptor, loadDataset: loadDatasetNodeDescriptor, + datasetNearestNeighbors: datasetNearestNeighborsNodeDescriptor, }; export function useNodeTypes(): NodeComponentDescriptors { diff --git a/packages/app/src/io/BrowserDatasetProvider.ts b/packages/app/src/io/BrowserDatasetProvider.ts index db5ce7462..c743b5f5a 100644 --- a/packages/app/src/io/BrowserDatasetProvider.ts +++ b/packages/app/src/io/BrowserDatasetProvider.ts @@ -137,4 +137,27 @@ export class BrowserDatasetProvider implements DatasetProvider { dataRequest.onerror = reject; }); } + + async knnDatasetRows( + datasetId: DatasetId, + k: number, + vector: number[], + ): Promise<(DatasetRow & { distance?: number })[]> { + const allRows = await this.getDatasetData(datasetId); + + const sorted = allRows.rows + .filter((row) => row.embedding != null) + .map((row) => ({ + row, + similarity: dotProductSimilarity(vector, row.embedding!), + })) + .sort((a, b) => b.similarity - a.similarity); + + return sorted.slice(0, k).map((r) => ({ ...r.row, distance: r.similarity })); + } } + +/** OpenAI embeddings are already normalized, so this is equivalent to cosine similarity */ +const dotProductSimilarity = (a: number[], b: number[]): number => { + return a.reduce((acc, val, i) => acc + val * b[i]!, 0); +}; diff --git a/packages/core/src/exports.ts b/packages/core/src/exports.ts index eadde9de6..f5a87606d 100644 --- a/packages/core/src/exports.ts +++ b/packages/core/src/exports.ts @@ -34,6 +34,7 @@ export * from './utils/inputs.js'; export * from './utils/newId.js'; export * from './utils/misc.js'; export * from './integrations/DatasetProvider.js'; +export * from './model/Dataset.js'; import * as openai from './utils/openai.js'; export { openai }; diff --git a/packages/core/src/integrations/DatasetProvider.ts b/packages/core/src/integrations/DatasetProvider.ts index d491ef864..2312c9e96 100644 --- a/packages/core/src/integrations/DatasetProvider.ts +++ b/packages/core/src/integrations/DatasetProvider.ts @@ -1,25 +1,5 @@ import { Opaque } from 'type-fest'; -import { ProjectId } from '../index.js'; - -export type DatasetId = Opaque; - -export type DatasetMetadata = { - id: DatasetId; - projectId: ProjectId; - name: string; - description: string; -}; - -export type Dataset = { - id: DatasetId; - rows: DatasetRow[]; -}; - -export type DatasetRow = { - id: string; - - data: string[]; -}; +import { Dataset, DatasetId, DatasetMetadata, DatasetRow, ProjectId } from '../index.js'; export interface DatasetProvider { getDatasetMetadata(id: DatasetId): Promise; @@ -35,4 +15,7 @@ export interface DatasetProvider { clearDatasetData(id: DatasetId): Promise; deleteDataset(id: DatasetId): Promise; + + /** Gets the K nearest neighbor rows to the given vector. */ + knnDatasetRows(datasetId: DatasetId, k: number, vector: number[]): Promise<(DatasetRow & { distance?: number })[]>; } diff --git a/packages/core/src/model/Dataset.ts b/packages/core/src/model/Dataset.ts new file mode 100644 index 000000000..19eb42eaa --- /dev/null +++ b/packages/core/src/model/Dataset.ts @@ -0,0 +1,25 @@ +import { Opaque } from 'type-fest'; +import { ProjectId } from '../index.js'; + +export type DatasetId = Opaque; + +export type DatasetMetadata = { + id: DatasetId; + projectId: ProjectId; + name: string; + description: string; +}; + +export type Dataset = { + id: DatasetId; + rows: DatasetRow[]; +}; + +export type DatasetRow = { + id: string; + + data: string[]; + + /** An optional embedding for the row's data. */ + embedding?: number[]; +}; diff --git a/packages/core/src/model/Nodes.ts b/packages/core/src/model/Nodes.ts index f3d6ec019..9aebcce81 100644 --- a/packages/core/src/model/Nodes.ts +++ b/packages/core/src/model/Nodes.ts @@ -179,6 +179,9 @@ export * from './nodes/GetAllDatasetsNode.js'; import { splitNode } from './nodes/SplitNode.js'; export * from './nodes/SplitNode.js'; +import { datasetNearestNeighborsNode } from './nodes/DatasetNearestNeigborsNode.js'; +export * from './nodes/DatasetNearestNeigborsNode.js'; + export const registerBuiltInNodes = (registry: NodeRegistration) => { return registry .register(toYamlNode) @@ -239,7 +242,8 @@ export const registerBuiltInNodes = (registry: NodeRegistration) => { .register(createDatasetNode) .register(loadDatasetNode) .register(getAllDatasetsNode) - .register(splitNode); + .register(splitNode) + .register(datasetNearestNeighborsNode); }; let globalRivetNodeRegistry = registerBuiltInNodes(new NodeRegistration()); diff --git a/packages/core/src/model/nodes/AppendToDatasetNode.ts b/packages/core/src/model/nodes/AppendToDatasetNode.ts index 6f7e916bd..6aad0e768 100644 --- a/packages/core/src/model/nodes/AppendToDatasetNode.ts +++ b/packages/core/src/model/nodes/AppendToDatasetNode.ts @@ -55,6 +55,12 @@ export class AppendToDatasetNodeImpl extends NodeImpl { title: 'ID', }); + inputDefinitions.push({ + id: 'embedding' as PortId, + dataType: 'vector', + title: 'Embedding', + }); + if (this.data.useDatasetIdInput) { inputDefinitions.push({ id: 'datasetId' as PortId, @@ -111,8 +117,8 @@ export class AppendToDatasetNodeImpl extends NodeImpl { } const datasetId = getInputOrData(this.data, inputs, 'datasetId', 'string') as DatasetId; - const dataId = coerceTypeOptional(inputs['id' as PortId], 'string') || newId(); + const embedding = coerceTypeOptional(inputs['embedding' as PortId], 'vector'); const dataInput = inputs['data' as PortId]; @@ -130,6 +136,7 @@ export class AppendToDatasetNodeImpl extends NodeImpl { newData.push({ id: dataId, data: stringData, + embedding, }); await datasetProvider.putDatasetData(datasetId, { diff --git a/packages/core/src/model/nodes/DatasetNearestNeigborsNode.ts b/packages/core/src/model/nodes/DatasetNearestNeigborsNode.ts new file mode 100644 index 000000000..bb878ee26 --- /dev/null +++ b/packages/core/src/model/nodes/DatasetNearestNeigborsNode.ts @@ -0,0 +1,138 @@ +import { + ChartNode, + DatasetId, + Inputs, + InternalProcessContext, + NodeId, + NodeImpl, + NodeInputDefinition, + NodeOutputDefinition, + NodeUIData, + Outputs, + PortId, + coerceTypeOptional, + dedent, + nodeDefinition, + newId, + EditorDefinition, + coerceType, + getInputOrData, +} from '../../index.js'; + +export type DatasetNearestNeighborsNode = ChartNode<'datasetNearestNeighbors', DatasetNearestNeighborsNodeData>; + +type DatasetNearestNeighborsNodeData = { + datasetId: DatasetId; + useDatasetIdInput?: boolean; + + k: number; + useKInput?: boolean; +}; + +export class DatasetNearestNeighborsNodeImpl extends NodeImpl { + static create(): DatasetNearestNeighborsNode { + return { + id: newId(), + type: 'datasetNearestNeighbors', + title: 'KNN Dataset', + visualData: { x: 0, y: 0, width: 250 }, + data: { + datasetId: '' as DatasetId, + k: 5, + }, + }; + } + + getInputDefinitions(): NodeInputDefinition[] { + const inputs: NodeInputDefinition[] = [ + { + id: 'embedding' as PortId, + title: 'Embedding', + dataType: 'object', + }, + ]; + + if (this.data.useDatasetIdInput) { + inputs.push({ + id: 'datasetId' as PortId, + title: 'Dataset ID', + dataType: 'string', + }); + } + + if (this.data.useKInput) { + inputs.push({ + id: 'k' as PortId, + title: 'K', + dataType: 'number', + }); + } + + return inputs; + } + + getOutputDefinitions(): NodeOutputDefinition[] { + return [ + { + id: 'nearestNeighbors' as PortId, + title: 'Nearest Neighbors', + dataType: 'object[]', + }, + ]; + } + + static getUIData(): NodeUIData { + return { + infoBoxBody: dedent` + Finds the k nearest neighbors in the dataset with the provided ID, given an embedding. + `, + infoBoxTitle: 'KNN Dataset Node', + contextMenuTitle: 'KNN Dataset', + group: ['Input/Output'], + }; + } + + getEditors(): EditorDefinition[] { + return [ + { + type: 'datasetSelector', + label: 'Dataset', + dataKey: 'datasetId', + useInputToggleDataKey: 'useDatasetIdInput', + }, + { + type: 'number', + label: 'K', + dataKey: 'k', + useInputToggleDataKey: 'useKInput', + }, + ]; + } + + async process(inputs: Inputs, context: InternalProcessContext): Promise { + const { datasetProvider } = context; + + if (datasetProvider == null) { + throw new Error('datasetProvider is required'); + } + + const datasetId = getInputOrData(this.data, inputs, 'datasetId'); + const k = getInputOrData(this.data, inputs, 'k', 'number'); + const embedding = coerceType(inputs['embedding' as PortId], 'vector'); + + const nearestNeighbors = await datasetProvider.knnDatasetRows(datasetId as DatasetId, k, embedding); + + return { + ['nearestNeighbors' as PortId]: { + type: 'object[]', + value: nearestNeighbors.map((neighbor) => ({ + id: neighbor.id, + distance: neighbor.distance, + data: neighbor.data, + })), + }, + }; + } +} + +export const datasetNearestNeighborsNode = nodeDefinition(DatasetNearestNeighborsNodeImpl, 'Dataset Nearest Neighbors');