Skip to content

Commit

Permalink
Rework core and engine subgraph handling
Browse files Browse the repository at this point in the history
  • Loading branch information
newcat committed Dec 23, 2023
1 parent f99b200 commit 432d6da
Show file tree
Hide file tree
Showing 15 changed files with 279 additions and 144 deletions.
6 changes: 6 additions & 0 deletions packages/core/src/editor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { Graph, IGraphState } from "./graph";
import { createGraphNodeType, getGraphNodeTypeString } from "./graphNode";
import { GraphTemplate, IGraphTemplateState } from "./graphTemplate";
import type { AbstractNode, AbstractNodeConstructor } from "./node";
import { GraphInputNode, GraphOutputNode } from "./graphInterface";

export interface IEditorState extends Record<string, any> {
graph: IGraphState;
Expand Down Expand Up @@ -82,6 +83,11 @@ export class Editor implements IBaklavaEventEmitter, IBaklavaTapable {
return this._loading;
}

public constructor() {
this.registerNodeType(GraphInputNode);
this.registerNodeType(GraphOutputNode);
}

/**
* Register a new node type
* @param type Actual type / constructor of the node
Expand Down
46 changes: 32 additions & 14 deletions packages/core/src/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@ import {
} from "@baklavajs/events";
import { Connection, DummyConnection, IConnection, IConnectionState } from "./connection";
import type { Editor } from "./editor";
import type { GraphTemplate } from "./graphTemplate";
import { type GraphTemplate } from "./graphTemplate";
import type { IAddConnectionEventData } from "./eventDataTypes";
import type { AbstractNode, INodeState } from "./node";
import type { NodeInterface } from "./nodeInterface";

export interface IGraphInterface {
id: string;
nodeInterfaceId: string;
name: string;
}
import {
GRAPH_INPUT_NODE_TYPE,
GRAPH_OUTPUT_NODE_TYPE,
IGraphInterface,
type GraphInputNode,
type GraphOutputNode,
} from "./graphInterface";

export interface IGraphState {
id: string;
nodes: Array<INodeState<unknown, unknown>>;
connections: IConnectionState[];
inputs: IGraphInterface[];
outputs: IGraphInterface[];
/** @deprecated */
inputs: Readonly<IGraphInterface[]>;
/** @deprecated */
outputs: Readonly<IGraphInterface[]>;
}

export interface CheckConnectionHookResult {
Expand All @@ -50,9 +53,6 @@ export class Graph implements IBaklavaEventEmitter, IBaklavaTapable {
public editor: Editor;
public template?: GraphTemplate;

public inputs: IGraphInterface[] = [];
public outputs: IGraphInterface[] = [];

public activeTransactions = 0;

protected _nodes: AbstractNode[] = [];
Expand Down Expand Up @@ -103,6 +103,26 @@ export class Graph implements IBaklavaEventEmitter, IBaklavaTapable {
return this._destroying;
}

public get inputs(): IGraphInterface[] {
const inputNodes = this.nodes.filter((n) => n.type === GRAPH_INPUT_NODE_TYPE) as GraphInputNode[];
return inputNodes.map((n) => ({
id: n.graphInterfaceId,
name: n.inputs.name.value,
nodeId: n.id,
nodeInterfaceId: n.outputs.placeholder.id,
}));
}

public get outputs(): IGraphInterface[] {
const outputNodes = this.nodes.filter((n) => n.type === GRAPH_OUTPUT_NODE_TYPE) as GraphOutputNode[];
return outputNodes.map((n) => ({
id: n.graphInterfaceId,
name: n.inputs.name.value,
nodeId: n.id,
nodeInterfaceId: n.outputs.output.id,
}));
}

public constructor(editor: Editor, template?: GraphTemplate) {
this.editor = editor;
this.template = template;
Expand Down Expand Up @@ -300,8 +320,6 @@ export class Graph implements IBaklavaEventEmitter, IBaklavaTapable {

// Load state
this.id = state.id;
this.inputs = state.inputs;
this.outputs = state.outputs;

for (const n of state.nodes) {
// find node type
Expand Down
77 changes: 77 additions & 0 deletions packages/core/src/graphInterface.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import { v4 as uuidv4 } from "uuid";
import { Node, INodeState, CalculateFunction, AbstractNode } from "./node";
import { NodeInterface } from "./nodeInterface";

export interface IGraphInterface {
id: string;
nodeId: string;
nodeInterfaceId: string;
name: string;
}

export const GRAPH_INPUT_NODE_TYPE = "__baklava_SubgraphInputNode";
export const GRAPH_OUTPUT_NODE_TYPE = "__baklava_SubgraphOutputNode";

interface IGraphInterfaceNodeState<I, O> extends INodeState<I, O> {
graphInterfaceId: string;
}

abstract class GraphInterfaceNode<I, O> extends Node<I, O> {
public graphInterfaceId: string;

constructor() {
super();
this.graphInterfaceId = uuidv4();
}

onPlaced() {
super.onPlaced();
this.initializeIo();
}

save(): IGraphInterfaceNodeState<I, O> {
return {
...super.save(),
graphInterfaceId: this.graphInterfaceId,
};
}

load(state: IGraphInterfaceNodeState<I, O>) {
super.load(state as INodeState<I, O>);
this.graphInterfaceId = state.graphInterfaceId;
}
}

export class GraphInputNode extends GraphInterfaceNode<{ name: string }, { placeholder: any }> {
public static isGraphInputNode(v: AbstractNode): v is GraphInputNode {
return v.type === GRAPH_INPUT_NODE_TYPE;
}

public override readonly type = GRAPH_INPUT_NODE_TYPE;
public inputs = {
name: new NodeInterface("Name", "Input"),
};
public outputs = {
placeholder: new NodeInterface("Value", undefined),
};
}
export type GraphInputNodeState = IGraphInterfaceNodeState<{ name: string }, { placeholder: any }>;

export class GraphOutputNode extends GraphInterfaceNode<{ name: string; placeholder: any }, { output: any }> {
public static isGraphOutputNode(v: AbstractNode): v is GraphOutputNode {
return v.type === GRAPH_OUTPUT_NODE_TYPE;
}

public override readonly type = GRAPH_OUTPUT_NODE_TYPE;
public inputs = {
name: new NodeInterface("Name", "Output"),
placeholder: new NodeInterface("Value", undefined),
};
public outputs = {
output: new NodeInterface("Output", undefined).setHidden(true),
};
public override calculate: CalculateFunction<{ placeholder: any }, { output: any }> = ({ placeholder }) => ({
output: placeholder,
});
}
export type GraphOutputNodeState = IGraphInterfaceNodeState<{ name: string; placeholder: any }, { output: any }>;
23 changes: 5 additions & 18 deletions packages/core/src/graphNode.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import {
GRAPH_TEMPLATE_INPUT_NODE_TYPE,
GraphTemplateInputNode,
type GraphTemplate,
GRAPH_TEMPLATE_OUTPUT_NODE_TYPE,
GraphTemplateOutputNode,
} from "./graphTemplate";
import { type GraphTemplate } from "./graphTemplate";
import { Graph, IGraphState } from "./graph";
import { AbstractNode, CalculateFunction, INodeState } from "./node";
import { NodeInterface } from "./nodeInterface";
Expand Down Expand Up @@ -55,11 +49,8 @@ export function createGraphNodeType(template: GraphTemplate): new () => Abstract
const graphInputs = context.engine.getInputValues(this.subgraph);

// fill subgraph input placeholders
const inputNodes = this.subgraph.nodes.filter(
(n) => n.type === GRAPH_TEMPLATE_INPUT_NODE_TYPE,
) as GraphTemplateInputNode[];
for (const inputNode of inputNodes) {
graphInputs.set(inputNode.outputs.placeholder.id, inputs[inputNode.graphInterfaceId]);
for (const input of this.subgraph.inputs) {
graphInputs.set(input.nodeInterfaceId, inputs[input.id]);
}

const result: Map<string, Map<string, any>> = await context.engine.runGraph(
Expand All @@ -69,12 +60,8 @@ export function createGraphNodeType(template: GraphTemplate): new () => Abstract
);

const outputs: Record<string, any> = {};
const outputNodes = this.subgraph.nodes.filter(
(n) => n.type === GRAPH_TEMPLATE_OUTPUT_NODE_TYPE,
) as unknown as GraphTemplateOutputNode[];
for (const outputNode of outputNodes) {
console.log("Output node ID", outputNode.id);
outputs[outputNode.graphInterfaceId] = result.get(outputNode.id)?.get("output");
for (const output of this.subgraph.outputs) {
outputs[output.id] = result.get(output.nodeId)?.get("output");
}

outputs._calculationResults = result;
Expand Down
97 changes: 36 additions & 61 deletions packages/core/src/graphTemplate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,21 @@ import { v4 as uuidv4 } from "uuid";
import { BaklavaEvent, SequentialHook } from "@baklavajs/events";
import type { IConnectionState } from "./connection";
import type { Editor } from "./editor";
import { Graph, IGraphInterface, IGraphState } from "./graph";
import { Node, type INodeState, CalculateFunction } from "./node";
import { Graph, IGraphState } from "./graph";
import type { INodeState } from "./node";
import type { INodeInterfaceState } from "./nodeInterface";
import { mapValues } from "./utils";
import { getGraphNodeTypeString } from "./graphNode";
import {
GRAPH_INPUT_NODE_TYPE,
GRAPH_OUTPUT_NODE_TYPE,
GraphInputNodeState,
GraphOutputNodeState,
IGraphInterface,
} from "./graphInterface";

type Optional<T, K extends keyof T> = Partial<Pick<T, K>> & Omit<T, K>;

export const GRAPH_TEMPLATE_INPUT_NODE_TYPE = "__baklava_SubgraphInputNode";
export const GRAPH_TEMPLATE_OUTPUT_NODE_TYPE = "__baklava_SubgraphOutputNode";

interface IGraphTemplateInterfaceState<I, O> extends INodeState<I, O> {
graphInterfaceId: string;
}

abstract class GraphTemplateInterfaceNode<I, O> extends Node<I, O> {
public graphInterfaceId: string;

constructor() {
super();
this.graphInterfaceId = uuidv4();
}

onPlaced() {
super.onPlaced();
this.initializeIo();
}

save(): IGraphTemplateInterfaceState<I, O> {
return {
...super.save(),
graphInterfaceId: this.graphInterfaceId,
};
}

load(state: IGraphTemplateInterfaceState<I, O>) {
super.load(state as INodeState<I, O>);
this.graphInterfaceId = state.graphInterfaceId;
}
}

export abstract class GraphTemplateInputNode extends GraphTemplateInterfaceNode<
{ name: string },
{ placeholder: any }
> {
public override readonly type = GRAPH_TEMPLATE_INPUT_NODE_TYPE;
}

export abstract class GraphTemplateOutputNode extends GraphTemplateInterfaceNode<
{ name: string; placeholder: any },
{ output: any }
> {
public override readonly type = GRAPH_TEMPLATE_OUTPUT_NODE_TYPE;
public override calculate: CalculateFunction<{ placeholder: any }, { output: any }> = ({ placeholder }) => ({
output: placeholder,
});
}

export interface IGraphTemplateState extends IGraphState {
name: string;
}
Expand All @@ -79,12 +36,6 @@ export class GraphTemplate implements IGraphState {
/** List of all connection states in this graph template */
public connections!: IConnectionState[];

/** List of all inputs to the graph template */
public inputs!: IGraphInterface[];

/** List of all outputs of the graph template */
public outputs!: IGraphInterface[];

/** Editor instance */
public editor: Editor;

Expand All @@ -105,6 +56,28 @@ export class GraphTemplate implements IGraphState {
}
}

/** List of all inputs to the graph template */
public get inputs(): Readonly<IGraphInterface[]> {
const inputNodes = this.nodes.filter((n) => n.type === GRAPH_INPUT_NODE_TYPE) as GraphInputNodeState[];
return inputNodes.map((n) => ({
id: n.graphInterfaceId,
name: n.inputs.name.value,
nodeId: n.id,
nodeInterfaceId: n.outputs.placeholder.id,
}));
}

/** List of all outputs of the graph template */
public get outputs(): Readonly<IGraphInterface[]> {
const outputNodes = this.nodes.filter((n) => n.type === GRAPH_OUTPUT_NODE_TYPE) as GraphOutputNodeState[];
return outputNodes.map((n) => ({
id: n.graphInterfaceId,
name: n.inputs.name.value,
nodeId: n.id,
nodeInterfaceId: n.outputs.output.id,
}));
}

constructor(state: Optional<IGraphTemplateState, "id" | "name">, editor: Editor) {
this.editor = editor;
if (state.id) {
Expand All @@ -130,8 +103,6 @@ export class GraphTemplate implements IGraphState {
public update(state: Omit<IGraphState, "id">) {
this.nodes = state.nodes;
this.connections = state.connections;
this.inputs = state.inputs;
this.outputs = state.outputs;
this.events.updated.emit();
}

Expand Down Expand Up @@ -194,12 +165,14 @@ export class GraphTemplate implements IGraphState {
const inputs: IGraphInterface[] = this.inputs.map((i) => ({
id: i.id,
name: i.name,
nodeId: getNewId(i.nodeId),
nodeInterfaceId: getNewId(i.nodeInterfaceId),
}));

const outputs: IGraphInterface[] = this.outputs.map((o) => ({
id: o.id,
name: o.name,
nodeId: getNewId(o.nodeId),
nodeInterfaceId: getNewId(o.nodeInterfaceId),
}));

Expand All @@ -214,7 +187,9 @@ export class GraphTemplate implements IGraphState {
if (!graph) {
graph = new Graph(this.editor);
}
graph.load(clonedState);
const warnings = graph.load(clonedState);
warnings.forEach((w) => console.warn(w));

graph.template = this;
return graph;
}
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ export * from "./connection";
export * from "./defineNode";
export * from "./dynamicNode";
export * from "./editor";
export * from "./engine";
export * from "./eventDataTypes";
export * from "./graph";
export * from "./graphInterface";
export * from "./graphNode";
export * from "./graphTemplate";
export * from "./node";
Expand Down
4 changes: 2 additions & 2 deletions packages/core/test/defineNode.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { defineNode, NodeInterface } from "../src";
import { defineNode, IEngine, NodeInterface } from "../src";

describe("defineNode", () => {
it("calls the onCreate lifecycle method correctly", () => {
Expand Down Expand Up @@ -63,7 +63,7 @@ describe("defineNode", () => {
calculate: calculateSpy,
});
const n = new TestNode();
const result = n.calculate!({ a: 4 }, { globalValues: { test: true }, engine: {} });
const result = n.calculate!({ a: 4 }, { globalValues: { test: true }, engine: {} as IEngine<void> });
expect(result).toEqual({ b: "5" });
expect(calculateSpy).toHaveBeenCalledWith({ a: 4 }, { globalValues: { test: true }, engine: {} });
});
Expand Down
Loading

0 comments on commit 432d6da

Please sign in to comment.