Skip to content

Commit

Permalink
More data studio work, knn node,
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Sep 19, 2023
1 parent 5e18cb8 commit 39e3eb4
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 31 deletions.
18 changes: 17 additions & 1 deletion packages/app/src/components/CodeEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,21 @@ export const CodeEditor: FC<{
theme?: string;
autoFocus?: boolean;
onKeyDown?: (e: monaco.IKeyboardEvent) => void;
onBlur?: () => void;
editorRef?: MutableRefObject<monaco.editor.IStandaloneCodeEditor | undefined>;
scrollBeyondLastLine?: boolean;
}> = ({ text, isReadonly, onChange, language, theme, autoFocus, onKeyDown, editorRef, scrollBeyondLastLine }) => {
}> = ({
text,
isReadonly,
onChange,
language,
theme,
autoFocus,
onKeyDown,
onBlur,
editorRef,
scrollBeyondLastLine,
}) => {
const editorContainer = useRef<HTMLDivElement>(null);
const editorInstance = useRef<monaco.editor.IStandaloneCodeEditor>();

Expand Down Expand Up @@ -56,6 +68,10 @@ export const CodeEditor: FC<{
onChangeLatest.current?.(editor.getValue());
});

editor.onDidBlurEditorWidget(() => {
onBlur?.();
});

editorInstance.current = editor;
if (editorRef) {
editorRef.current = editor;
Expand Down
25 changes: 18 additions & 7 deletions packages/app/src/components/DataStudio.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { parse as parseCsv } from 'csv-parse/browser/esm/sync';
import { ioProvider } from '../utils/globals';
import { useDataset } from '../hooks/useDataset';
import { useDatasets } from '../hooks/useDatasets';
import { LazyCodeEditor } from './LazyComponents';

export const DataStudioRenderer: FC = () => {
const [openOverlay, setOpenOverlay] = useRecoilState(overlayOpenState);
Expand Down Expand Up @@ -326,13 +327,21 @@ const datasetDisplayStyles = css`
.cell {
height: 48px;
&.editor {
.editor-container {
min-height: 200px;
min-width: 100px;
}
}
}
.value {
padding: 4px 8px;
height: 100%;
display: flex;
align-items: center;
align-items: flex-start;
overflow: hidden;
}
}
}
Expand Down Expand Up @@ -503,19 +512,21 @@ const DatasetEditableCell: FC<{
onChange: (value: string) => void;
}> = ({ value, row, column, onChange }) => {
const [editing, setEditing] = useState(false);
const [editingText, setEditingText] = useState(value);

return (
<div className="cell" data-row={row} data-column={column} data-contextmenutype="cell">
<div className="cell editor" data-row={row} data-column={column} data-contextmenutype="cell">
{editing ? (
<TextField
<LazyCodeEditor
autoFocus
defaultValue={value}
onBlur={(e) => {
onChange((e.target as HTMLInputElement).value);
text={value}
onChange={(e) => setEditingText(e)}
onBlur={() => {
onChange(editingText);
setEditing(false);
}}
onKeyDown={(e) => {
if (e.key === 'Enter') {
if (e.keyCode === 3 && (e.metaKey || e.ctrlKey)) {
onChange((e.target as HTMLInputElement).value);
setEditing(false);
}
Expand Down
27 changes: 27 additions & 0 deletions packages/app/src/components/nodes/DatasetNearestNeighborsNode.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { DatasetNearestNeighborsNode } from '@ironclad/rivet-core';
import { NodeComponentDescriptor } from '../../hooks/useNodeTypes';
import { FC } from 'react';
import { useDatasets } from '../../hooks/useDatasets';
import { useRecoilValue } from 'recoil';
import { projectState } from '../../state/savedGraphs';

/**
 * Body component for the dataset nearest-neighbors node.
 *
 * Renders a single line of text describing where the node's dataset comes from:
 * a fixed label when the dataset id is supplied through an input port, the name of
 * the matching dataset when one is found in the current project's datasets, or a
 * fallback message otherwise.
 */
export const DatasetNearestNeightborsNode: FC<{
  node: DatasetNearestNeighborsNode;
}> = ({ node }) => {
  const { metadata } = useRecoilValue(projectState);
  const { datasets } = useDatasets(metadata.id);

  // `datasets` may be undefined, in which case no dataset is matched.
  const selectedDataset = datasets?.find((candidate) => candidate.id === node.data.datasetId);

  const describeDataset = () => {
    if (node.data.useDatasetIdInput) {
      return 'Dataset from input';
    }
    if (selectedDataset) {
      return selectedDataset.name;
    }
    return 'Unknown or no dataset selected';
  };

  return (
    <div>
      <div>{describeDataset()}</div>
    </div>
  );
};

/**
 * Component descriptor for the 'datasetNearestNeighbors' node type.
 * Supplies DatasetNearestNeightborsNode as the node's Body component.
 */
export const datasetNearestNeighborsNodeDescriptor: NodeComponentDescriptor<'datasetNearestNeighbors'> = {
Body: DatasetNearestNeightborsNode,
};
2 changes: 2 additions & 0 deletions packages/app/src/hooks/useNodeTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { appendToDatasetNodeDescriptor } from '../components/nodes/AppendToDatas
import { useRecoilValue } from 'recoil';
import { pluginRefreshCounterState } from '../state/plugins';
import { loadDatasetNodeDescriptor } from '../components/nodes/LoadDatasetNode';
import { datasetNearestNeighborsNodeDescriptor } from '../components/nodes/DatasetNearestNeighborsNode';

export type UnknownNodeComponentDescriptor = {
Body?: FC<{ node: ChartNode }>;
Expand Down Expand Up @@ -55,6 +56,7 @@ const overriddenDescriptors: Partial<NodeComponentDescriptors> = {
audio: audioNodeDescriptor,
appendToDataset: appendToDatasetNodeDescriptor,
loadDataset: loadDatasetNodeDescriptor,
datasetNearestNeighbors: datasetNearestNeighborsNodeDescriptor,
};

export function useNodeTypes(): NodeComponentDescriptors {
Expand Down
23 changes: 23 additions & 0 deletions packages/app/src/io/BrowserDatasetProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,27 @@ export class BrowserDatasetProvider implements DatasetProvider {
dataRequest.onerror = reject;
});
}

/**
 * Gets the rows of a dataset that are most similar to the given vector.
 *
 * Every row that has an embedding of the same dimension as `vector` is scored with a
 * dot product against `vector`; the rows are then ordered from highest to lowest score
 * and the first `k` are returned.
 *
 * Note that the returned `distance` field holds the similarity score itself, so a
 * HIGHER value means a closer match.
 *
 * @param datasetId - The dataset whose rows are searched.
 * @param k - Maximum number of rows to return. Values below zero are treated as zero.
 * @param vector - The query vector to compare row embeddings against.
 * @returns Up to `k` rows, best match first, each with its score in `distance`.
 */
async knnDatasetRows(
  datasetId: DatasetId,
  k: number,
  vector: number[],
): Promise<(DatasetRow & { distance?: number })[]> {
  const allRows = await this.getDatasetData(datasetId);

  const scored: { row: DatasetRow; similarity: number }[] = [];
  for (const row of allRows.rows) {
    const { embedding } = row;

    // Rows without an embedding cannot be scored. Rows whose embedding has a different
    // dimension than the query would produce a NaN score, and a NaN comparator result
    // makes the sort order below unpredictable, so those rows are skipped as well.
    if (embedding == null || embedding.length !== vector.length) {
      continue;
    }

    scored.push({ row, similarity: dotProductSimilarity(vector, embedding) });
  }

  scored.sort((a, b) => b.similarity - a.similarity);

  // Clamp k so that a negative value yields no rows; slice(0, -n) would otherwise
  // return everything except the last n rows.
  const limit = Math.max(0, k);

  return scored.slice(0, limit).map(({ row, similarity }) => ({ ...row, distance: similarity }));
}
}

/**
 * Computes the dot product of two vectors, pairing elements by index across the
 * length of `a`.
 *
 * OpenAI embeddings are already normalized, so for them this is equivalent to
 * cosine similarity.
 */
const dotProductSimilarity = (a: number[], b: number[]): number => {
  let total = 0;
  a.forEach((component, index) => {
    total = total + component * b[index]!;
  });
  return total;
};
1 change: 1 addition & 0 deletions packages/core/src/exports.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export * from './utils/inputs.js';
export * from './utils/newId.js';
export * from './utils/misc.js';
export * from './integrations/DatasetProvider.js';
export * from './model/Dataset.js';

import * as openai from './utils/openai.js';
export { openai };
25 changes: 4 additions & 21 deletions packages/core/src/integrations/DatasetProvider.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,5 @@
import { Opaque } from 'type-fest';
import { ProjectId } from '../index.js';

export type DatasetId = Opaque<string, 'DatasetId'>;

export type DatasetMetadata = {
id: DatasetId;
projectId: ProjectId;
name: string;
description: string;
};

export type Dataset = {
id: DatasetId;
rows: DatasetRow[];
};

export type DatasetRow = {
id: string;

data: string[];
};
import { Dataset, DatasetId, DatasetMetadata, DatasetRow, ProjectId } from '../index.js';

export interface DatasetProvider {
getDatasetMetadata(id: DatasetId): Promise<DatasetMetadata[]>;
Expand All @@ -35,4 +15,7 @@ export interface DatasetProvider {
clearDatasetData(id: DatasetId): Promise<void>;

deleteDataset(id: DatasetId): Promise<void>;

/** Gets the K nearest neighbor rows to the given vector. */
knnDatasetRows(datasetId: DatasetId, k: number, vector: number[]): Promise<(DatasetRow & { distance?: number })[]>;
}
25 changes: 25 additions & 0 deletions packages/core/src/model/Dataset.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { Opaque } from 'type-fest';
import { ProjectId } from '../index.js';

/** Opaque string type used as the identifier of a dataset. */
export type DatasetId = Opaque<string, 'DatasetId'>;

/**
 * Descriptive information about a dataset, separate from its rows.
 * Carries the dataset's id, a project id (presumably the project the dataset
 * belongs to; confirm against callers), a name, and a description.
 */
export type DatasetMetadata = {
id: DatasetId;
projectId: ProjectId;
name: string;
description: string;
};

/** The contents of a dataset: the dataset's id together with all of its rows. */
export type Dataset = {
id: DatasetId;
rows: DatasetRow[];
};

/** A single row of a dataset. */
export type DatasetRow = {
/** Identifier of the row. */
id: string;

/** The row's values, stored as strings (presumably one entry per column; confirm against callers). */
data: string[];

/** An optional embedding for the row's data. */
embedding?: number[];
};
6 changes: 5 additions & 1 deletion packages/core/src/model/Nodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ export * from './nodes/GetAllDatasetsNode.js';
import { splitNode } from './nodes/SplitNode.js';
export * from './nodes/SplitNode.js';

import { datasetNearestNeighborsNode } from './nodes/DatasetNearestNeigborsNode.js';
export * from './nodes/DatasetNearestNeigborsNode.js';

export const registerBuiltInNodes = (registry: NodeRegistration) => {
return registry
.register(toYamlNode)
Expand Down Expand Up @@ -239,7 +242,8 @@ export const registerBuiltInNodes = (registry: NodeRegistration) => {
.register(createDatasetNode)
.register(loadDatasetNode)
.register(getAllDatasetsNode)
.register(splitNode);
.register(splitNode)
.register(datasetNearestNeighborsNode);
};

let globalRivetNodeRegistry = registerBuiltInNodes(new NodeRegistration());
Expand Down
9 changes: 8 additions & 1 deletion packages/core/src/model/nodes/AppendToDatasetNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ export class AppendToDatasetNodeImpl extends NodeImpl<AppendToDatasetNode> {
title: 'ID',
});

inputDefinitions.push({
id: 'embedding' as PortId,
dataType: 'vector',
title: 'Embedding',
});

if (this.data.useDatasetIdInput) {
inputDefinitions.push({
id: 'datasetId' as PortId,
Expand Down Expand Up @@ -111,8 +117,8 @@ export class AppendToDatasetNodeImpl extends NodeImpl<AppendToDatasetNode> {
}

const datasetId = getInputOrData(this.data, inputs, 'datasetId', 'string') as DatasetId;

const dataId = coerceTypeOptional(inputs['id' as PortId], 'string') || newId<DatasetId>();
const embedding = coerceTypeOptional(inputs['embedding' as PortId], 'vector');

const dataInput = inputs['data' as PortId];

Expand All @@ -130,6 +136,7 @@ export class AppendToDatasetNodeImpl extends NodeImpl<AppendToDatasetNode> {
newData.push({
id: dataId,
data: stringData,
embedding,
});

await datasetProvider.putDatasetData(datasetId, {
Expand Down
Loading

0 comments on commit 39e3eb4

Please sign in to comment.