Skip to content

Commit

Permalink
Run updateRAM on both CPU and GPU nets
Browse files Browse the repository at this point in the history
  • Loading branch information
voidvoxel committed Jun 17, 2024
1 parent b703e4a commit fda0349
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 52 deletions.
37 changes: 26 additions & 11 deletions src/neural-network-gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,22 @@ export class NeuralNetworkGPU<
if (!value) {
if (this._ramKernel) delete this._ramKernel;
}
else this._ramKernel = this.gpu.createKernel(value);
else {
const layerCount = this.sizes.length;
const maxNeuronsPerLayer = this.sizes.reduce(
(eax, edx) => edx > eax ? edx : eax
);
const ramSize = this.ramSize;
this._ramKernel = this.gpu.createKernel(
value,
{
constants: {
ramSize
},
output: [ layerCount, maxNeuronsPerLayer, ramSize ]
}
);
}
super.ramFunction = value;
}

Expand Down Expand Up @@ -380,16 +395,6 @@ export class NeuralNetworkGPU<
});
}

const updateRAM: IKernelRunShortcut | undefined = this._ramKernel;

if (updateRAM) {
const input = this.outputs[0];
const output = this.outputs[this.outputLayer];
const loss = this.loss.current.mean;
const deltaLoss = loss - this.loss.previous.mean;
updateRAM(this.ram, this.ramSize, input, output, this.sizes, loss, deltaLoss);
}

this.texturizeInputData = this.gpu.createKernel(
function (value: number[]): number {
return value[this.thread.x];
Expand All @@ -416,6 +421,16 @@ export class NeuralNetworkGPU<
);
output = input = this.outputs[layer];
}
const updateRAM: IKernelRunShortcut | undefined = this._ramKernel;
if (updateRAM) {
const input = this.outputs[0];
const output = this.outputs[this.outputLayer];
const loss = this.loss.current.mean;
const deltaLoss = loss - this.loss.previous.mean;
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
this._ram = updateRAM(this.ram, input, output, this.sizes, loss, deltaLoss);
}
return output;
};

Expand Down
93 changes: 52 additions & 41 deletions src/neural-network.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ export type LossFunction = (
export type RAMFunction = (
this: IKernelFunctionThis,
ram: NeuralNetworkRAM,
ramSize: number,
inputs: NeuralNetworkIO,
outputs: NeuralNetworkIO,
sizes: number[],
Expand All @@ -44,9 +43,9 @@ export interface ILossAnalyticsSnapshot {
}

const EMPTY_LOSS_SNAPSHOT: ILossAnalyticsSnapshot = {
mean: NaN,
median: NaN,
total: NaN
mean: Number.MAX_SAFE_INTEGER,
median: Number.MAX_SAFE_INTEGER,
total: Number.MAX_SAFE_INTEGER
};

Object.freeze(EMPTY_LOSS_SNAPSHOT);
Expand Down Expand Up @@ -263,7 +262,9 @@ export class NeuralNetwork<

runInput: (input: Float32Array) => Float32Array = (input: Float32Array) => {
this.setActivation();
return this.runInput(input);
const output = this.runInput(input);
this._updateRAM();
return output;
};

calculateDeltas: (output: Float32Array, input: Float32Array) => void = (
Expand Down Expand Up @@ -301,6 +302,7 @@ export class NeuralNetwork<
// Initialize the loss function.
this._lossAnalytics = createLossAnalytics();
if (options.loss) this._lossFunction = options.loss;
if (options.updateRAM) this._ramFunction = options.updateRAM;
}

/**
Expand Down Expand Up @@ -426,6 +428,7 @@ export class NeuralNetwork<
}
this.validateInput(formattedInput);
const output = this.runInput(formattedInput).slice(0);
this._updateRAM();
if (this.outputLookup) {
return (lookup.toObject(
this.outputLookup,
Expand All @@ -435,6 +438,47 @@ export class NeuralNetwork<
return (output as unknown) as OutputType;
}

protected _updateRAM() {
if (this.ram) {
const updateRAM: RAMFunction | undefined = this.ramFunction;

if (updateRAM) {
const ramSize = this.ramSize;
const input = this.outputs[0];
const output = this.outputs[this.outputLayer];
const loss = this.loss.current.mean;
const deltaLoss = loss - this.loss.previous.mean;
this._ram = this.ram.map(
(layerRAM, layer) => layerRAM.map(
(neuronRAM, neuron) => neuronRAM.map(
(value, index) => {
const kernelFunctionThis: IKernelFunctionThis = {
color: function color(r: number, g: number = 0, b: number = 0, a: number = 0) {},
constants: {
ramSize
},
output: {
x: NaN,
y: NaN,
z: NaN
},
thread: {
x: index,
y: neuron,
z: layer
}
};
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
return updateRAM.call(kernelFunctionThis, this.ram, this.ramSize, input, output, this.sizes, loss, deltaLoss);
}
)
)
);
}
}
}

_runInputSigmoid(input: Float32Array): Float32Array {
this.outputs[0] = input; // set output state of input layer

Expand Down Expand Up @@ -852,47 +896,12 @@ export class NeuralNetwork<
): number | null {
// forward propagate
this.runInput(value.input);
this._updateRAM();

// back propagate
this.calculateDeltas(value.output, value.input);
this.adjustWeights();

if (this.ram) {
const updateRAM: RAMFunction | undefined = this.ramFunction;

if (updateRAM) {
const input = this.outputs[0];
const output = this.outputs[this.outputLayer];
const loss = this.loss.current.mean;
const deltaLoss = loss - this.loss.previous.mean;
this.ram.map(
(layerRAM, layer) => layerRAM.map(
(neuronRAM, neuron) => neuronRAM.map(
(value, index) => {
const kernelFunctionThis: IKernelFunctionThis = {
color: function color(r: number, g: number = 0, b: number = 0, a: number = 0) {},
constants: {},
output: {
x: NaN,
y: NaN,
z: NaN
},
thread: {
x: index,
y: neuron,
z: layer
}
};
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
return updateRAM.call(kernelFunctionThis, this.ram, this.ramSize, input, output, this.sizes, loss, deltaLoss);
}
)
)
)
}
}

if (logErrorRate) {
return mse(this.errors[this.outputLayer]);
}
Expand Down Expand Up @@ -1290,6 +1299,7 @@ export class NeuralNetwork<

for (let i = 0; i < preparedData.length; i++) {
const output = this.runInput(preparedData[i].input);
this._updateRAM();
const target = preparedData[i].output;
const actual = output[0] > this.options.binaryThresh ? 1 : 0;
const expected = target[0];
Expand Down Expand Up @@ -1337,6 +1347,7 @@ export class NeuralNetwork<

for (let i = 0; i < preparedData.length; i++) {
const output = this.runInput(preparedData[i].input);
this._updateRAM();
const target = preparedData[i].output;
const actual = output.indexOf(max(output));
const expected = target.indexOf(max(target));
Expand Down

0 comments on commit fda0349

Please sign in to comment.