diff --git a/src/recurrent.ts b/src/recurrent.ts
index 28482875..d4de4408 100644
--- a/src/recurrent.ts
+++ b/src/recurrent.ts
@@ -426,4 +426,16 @@ export class Recurrent<
     }
     return null;
   }
+
+  // Restore state from JSON: recreate each layer via its constructor and rehydrate it.
+  fromJSON(json: any): void {
+    super.fromJSON(json);
+    this._layerSets = json.layerSets.map((layerSet: any) =>
+      layerSet.map((layer: any) => {
+        const newLayer = new (layer.constructor as any)();
+        newLayer.fromJSON(layer);
+        return newLayer;
+      })
+    );
+  }
 }
diff --git a/src/recurrent/lstm.test.ts b/src/recurrent/lstm.test.ts
index 90eeb417..be27b671 100644
--- a/src/recurrent/lstm.test.ts
+++ b/src/recurrent/lstm.test.ts
@@ -192,4 +192,51 @@ describe('LSTM', () => {
       expect(net.run([transactionTypes.other])).toBe('other');
     });
   });
+
+  describe('cloned LSTM net training', () => {
+    it('continues training from the point where the original stopped', () => {
+      const net = new LSTM({ hiddenLayers: [60, 60] });
+      net.maxPredictionLength = 100;
+
+      const trainData = [
+        'doe, a deer, a female deer',
+        'ray, a drop of golden sun',
+        'me, a name I call myself',
+      ];
+
+      // First training pass on the original net:
+      net.train(trainData, {
+        iterations: 5000,
+        log: true,
+        logPeriod: 500,
+        learningRate: 0.2,
+      });
+
+      // Clone the net via JSON serialization:
+      const net2 = new LSTM({ hiddenLayers: [60, 60] });
+      net2.fromJSON(net.toJSON());
+
+      // Immediately after cloning, both produce the same output:
+      expect(net.run('ray')).toBe(net2.run('ray'));
+
+      // Continue training the original from where it left off:
+      net.train(trainData, {
+        iterations: 30,
+        log: true,
+        logPeriod: 10,
+        learningRate: 0.2,
+      });
+
+      // Continue training the clone with the same settings:
+      net2.train(trainData, {
+        iterations: 30,
+        log: true,
+        logPeriod: 10,
+        learningRate: 0.2,
+      });
+
+      // Trained independently after the clone point, the two nets diverge:
+      expect(net.run('ray')).not.toBe(net2.run('ray'));
+    });
+  });
 });
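
A minimal usage sketch of the clone-and-resume workflow this patch enables. It only uses calls that already appear in the diff (the LSTM constructor, train, run, toJSON, and the new fromJSON); the import path and the small training settings are illustrative assumptions, not part of the patch.

```ts
// Sketch: clone a trained LSTM with toJSON/fromJSON and keep training the copy.
// NOTE: the import path is an assumption; adjust to however LSTM is exported.
import { LSTM } from './recurrent/lstm';

const original = new LSTM({ hiddenLayers: [20] });
original.train(['doe, a deer, a female deer'], { iterations: 200 });

// Serialize and rehydrate into a fresh instance of the same shape.
const snapshot = original.toJSON();
const clone = new LSTM({ hiddenLayers: [20] });
clone.fromJSON(snapshot);

// The clone starts from the original's weights...
console.log(clone.run('doe') === original.run('doe')); // expected: true

// ...and can continue training independently of the original.
clone.train(['ray, a drop of golden sun'], { iterations: 50 });
```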