From bb364734ba4b16dadee18585e10f3e728ac0a913 Mon Sep 17 00:00:00 2001
From: karita
Date: Tue, 19 Jun 2018 22:09:08 +0900
Subject: [PATCH] refactor optim

---
 example/char_rnn.d   |  5 ++---
 example/mnist.d      |  4 ++--
 source/grain/optim.d | 55 ++++++++++++++++++++++++++++---------------------
 3 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/example/char_rnn.d b/example/char_rnn.d
index 1c29c27..b887ca7 100644
--- a/example/char_rnn.d
+++ b/example/char_rnn.d
@@ -130,7 +130,6 @@ void main() {
     auto model = RNN!Storage(vocabSize, hiddenSize);
 
     // for optim
-    // SGD optim = { lr: learningRate };
     auto optim = AdaGrad!(typeof(model))(model, learningRate);
     auto smoothLoss = -log(1.0 / vocabSize) * seqLength;
     size_t beginId = 0;
@@ -154,9 +153,9 @@ void main() {
         // forward seq_length characters through the net and fetch gradient
         model.zeroGrad();
         auto ret = model.accumGrad(ids.sliced.unsqueeze!0, hprev);
-        hprev = ret.hprev;
-        optim.update(model);
+        optim.update();
         smoothLoss = smoothLoss * 0.999 + ret.loss * 0.001;
+        hprev = ret.hprev;
         if (nIter % logIter == 0) {
             writefln!"iter %d, loss: %f, iter/sec: %f"(
                 nIter, smoothLoss,
diff --git a/example/mnist.d b/example/mnist.d
index 25218b0..0ea808a 100644
--- a/example/mnist.d
+++ b/example/mnist.d
@@ -118,7 +118,7 @@ void main() {
     auto trainBatch = datasets.train.makeBatch(batchSize);
     auto testBatch = datasets.test.makeBatch(batchSize);
     auto model = MLP!(float, S)(inSize, 512, 10);
-    SGD optimizer = {lr: 1e-2};
+    auto optimizer = SGD!(typeof(model))(model, 1e-2);
 
     foreach (epoch; 0 .. 10) {
         // TODO implement model.train();
@@ -135,7 +135,7 @@ void main() {
             accSum += acc;
            model.zeroGrad();
            loss.backward();
-            optimizer.update(model);
+            optimizer.update();
        }
        writefln!"train loss: %f, acc: %f"(lossSum / niter, accSum / niter);
    }
diff --git a/source/grain/optim.d b/source/grain/optim.d
index 67f1327..4a8c389 100644
--- a/source/grain/optim.d
+++ b/source/grain/optim.d
@@ -33,7 +33,7 @@ enum bool isOptimizer(T) = is(typeof({
 
 
 /// kind of std.algorithm.each for iterating variables inside a chain
-void iterVariables(alias proc, C)(ref C chain, string prefix="") {
+void iterVariables(alias proc, C)(C* chain, string prefix="") {
     import std.traits;
     import grain.autograd;
 
@@ -44,11 +44,12 @@ void iterVariables(alias proc, C)(ref C chain, string prefix="") {
         static if (isVariable!V) {
             proc(fullName, value);
         } else static if (hasMember!(V, "tupleof")) {
-            iterVariables!proc(value, fullName);
+            iterVariables!proc(&value, fullName);
         }
     }
 }
 
+/*
 enum variableNames(C) = {
     string[] ret;
     void register(V)(string k, V v) if (isVariable!V) {
@@ -70,12 +71,13 @@ unittest {
     StateDict dict;
     iterVariables!( (k, v) { dict[k] = UntypedVariable(v); } )(mlp);
 }
+*/
 
 alias StateDict = UntypedVariable[string];
 
 
-void update(O, C)(ref O optimizer, ref C chain, string attr = "") { // if (isOptimizer!O) {
-    iterVariables!( (k, v) {optimizer.step(k, v);} )(chain);
+void update(O)(ref O optimizer) { // if (isOptimizer!O) {
+    iterVariables!( (k, v) {optimizer.step(k, v);} )(optimizer.target, "");
 }
 
@@ -89,10 +91,15 @@ void transform(T, size_t dim)(Variable!(T, dim, HostStorage) src, ref Variable!(T, dim, HostStorage) dst, T alpha=1, T beta=0) {
 
 
 /// stochastic gradient descent optimizer
-struct SGD {
+struct SGD(Chain) {
+    Chain* target;
     float lr = 1.0;
     // float momentum = 0.0;
     // float weightDecay = 0.0;
+    this(ref Chain target, float lr=1.0) {
+        this.target = &target;
+        this.lr = lr;
+    }
 
     void step(V)(string name, ref V field) if (isVariable!V) {
         // transform(field.gradVariable, field, -this.lr, 1.0);
@@ -143,10 +150,10 @@ unittest {
         mlp.zeroGrad();
         assert(mlp.fc1.weight.grad[0] == 0.0);
 
-        auto sgd = SGD(0.5);
+        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
         mlp.fc1.weight.data.zero_();
         mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
-        sgd.update(mlp);
+        sgd.update();
         assert(mlp.fc1.weight.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
     }
     version (grain_cuda) {
@@ -155,10 +162,10 @@ unittest {
         mlp.zeroGrad();
         assert(mlp.fc1.weight.to!HostStorage.gradSliced == [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
 
-        auto sgd = SGD(0.5);
+        auto sgd = SGD!(typeof(mlp))(mlp, 0.5);
         mlp.fc1.weight.data.zero_();
         mlp.fc1.weight.grad = [[1.0f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.to!DeviceStorage.data;
-        sgd.update(mlp);
+        sgd.update();
         assert(mlp.fc1.weight.to!HostStorage.sliced == [[-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
     }
 }
@@ -168,14 +175,16 @@ unittest {
 struct AdaGrad(Chain) {
     import grain.autograd;
 
+    Chain* target;
     float lr = 1.0;
     float eps = 1e-8;
     StateDict memory;
 
-    this(ref Chain model, float lr=1e-3, float eps=1e-8) {
+    this(ref Chain target, float lr=1e-3, float eps=1e-8) {
+        this.target = &target;
         this.lr = lr;
         this.eps = eps;
-        iterVariables!((k, v) { this.initStates(k, v); })(model);
+        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
     }
 
     void initStates(V)(string name, ref V field) if (isVariable!V) {
@@ -210,7 +219,7 @@ unittest {
     static assert(isOptimizer!(typeof(optim)));
     model.fc1.weight.data.zero_();
     model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
-    optim.update(model);
+    optim.update();
     auto w = model.fc1.weight;
     assert(approxEqual(w.sliced, [[-lr * 0.2 / (0.2 * 0.2 + eps) ^^ 0.5, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
     auto m = optim.memory[".fc1.weight"].to!(typeof(w));
@@ -219,7 +228,7 @@ unittest {
 
     version (grain_cuda) {
         auto model = MLP!(float, DeviceStorage)(3);
         auto optim = AdaGrad!(typeof(model))(model, 0.1);
-        optim.update(model);
+        optim.update();
     }
 }
@@ -229,6 +238,7 @@ unittest {
 struct Adam(Chain) {
     import grain.autograd;
 
+    Chain* target;
     float lr = 1.0;
     float beta1 = 0.9;
     float beta2 = 0.999;
@@ -236,10 +246,11 @@ struct Adam(Chain) {
 
     StateDict moment1, moment2;
 
-    this(ref Chain model, float lr, float eps=1e-8) {
+    this(ref Chain target, float lr, float eps=1e-8) {
+        this.target = &target;
         this.lr = lr;
         this.eps = eps;
-        iterVariables!((k, v) { this.initStates(k, v); })(model);
+        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
     }
 
     void initStates(V)(string name, ref V field) if (isVariable!V) {
@@ -280,7 +291,7 @@ unittest {
     static assert(isOptimizer!(typeof(optim)));
     model.fc1.weight.data.zero_();
     model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
-    optim.update(model);
+    optim.update();
     auto w = model.fc1.weight;
     auto m1 = (1.0 - optim.beta1) * (0.2 - 0.0) + 0.0;
     auto m2 = (1.0 - optim.beta2) * (0.2 * 0.2 - 0.0) + 0.0;
@@ -293,7 +304,7 @@ unittest {
 
     version (grain_cuda) {
         auto model = MLP!(float, DeviceStorage)(3);
         auto optim = Adam!(typeof(model))(model, 0.1);
-        optim.update(model);
+        optim.update();
     }
 }
@@ -302,17 +313,19 @@ unittest {
 struct AdaDelta(Chain) {
     import grain.autograd;
 
+    Chain* target;
     float lr = 1.0;
     float rho = 0.95;
     float eps = 1e-6;
 
     StateDict den, num;
 
-    this(ref Chain model, float lr=1.0, float rho=0.95, float eps=1e-8) {
+    this(ref Chain target, float lr=1.0, float rho=0.95, float eps=1e-8) {
+        this.target = &target;
         this.lr = lr;
         this.rho = rho;
         this.eps = eps;
-        iterVariables!((k, v) { this.initStates(k, v); })(model);
+        iterVariables!((k, v) { this.initStates(k, v); })(this.target);
     }
 
     void initStates(V)(string name, ref V field) if (isVariable!V) {
@@ -353,7 +366,7 @@ unittest {
     // static assert(isOptimizer!(typeof(optim)));
     model.fc1.weight.data.zero_();
     model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable.data;
-    optim.update(model);
+    optim.update();
     auto w = model.fc1.weight;
     auto d = (1.0 - optim.rho) * 0.2 * 0.2;
     auto diff = cast(float) ((0.0 + optim.eps) / (d + optim.eps)) ^^ 0.5;
@@ -367,6 +380,6 @@ unittest {
     version (grain_cuda) {
         auto model = MLP!(float, DeviceStorage)(3);
         auto optim = AdaDelta!(typeof(model))(model);
-        optim.update(model);
+        optim.update();
     }
 }
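
Usage note (not part of the diff above): after this refactor each optimizer struct is a template over the model type and keeps a Chain* target pointing at the model passed to its constructor, so update() takes no argument and walks the stored target through iterVariables. A minimal sketch of the new call pattern, assuming the MLP test chain and HostStorage used in the unittests above; any struct of layers containing variables works the same way:

    auto model = MLP!(float, HostStorage)(3);
    auto optim = SGD!(typeof(model))(model, 0.1); // constructor keeps &model as optim.target

    model.zeroGrad();
    // ... forward/backward pass that fills the gradients ...
    optim.update();                               // was optim.update(model) before this patch

Because the optimizer only holds a pointer, the model has to outlive the optimizer and should not be moved or copied after the optimizer is constructed.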