diff --git a/src/optim/init.lua b/src/optim/init.lua
index 03d2563..62a7494 100644
--- a/src/optim/init.lua
+++ b/src/optim/init.lua
@@ -1,10 +1,11 @@
 local util = require 'autograd.util'
 
 local function wrap(optimfn)
-   return function(fn, state, params)
-      local states = { }
+   return function(fn, config, state, params)
+      local configs, states = { }, { }
       local flatParams = util.sortedFlatten(params)
       for i = 1, #flatParams do
+         configs[i] = util.deepCopy(config)
          states[i] = util.deepCopy(state)
       end
       return function(...)
@@ -15,10 +16,10 @@ local function wrap(optimfn)
             local grad = flatGrads[i]
             optimfn(function()
               return loss, grad
-            end, flatParams[i], states[i])
+            end, flatParams[i], configs[i], states[i])
          end
          return table.unpack(out)
-      end, states
+      end, configs, states
    end
 end
 
@@ -28,4 +29,4 @@ for k, v in pairs(require 'optim') do
    opt[k] = wrap(v)
 end
 
-return opt
\ No newline at end of file
+return opt
diff --git a/test/test.lua b/test/test.lua
index 63aa88c..f24b968 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1530,10 +1530,10 @@ local tests = {
            end
         end
 
-      local state = { learningRate = learningRate }
+      local config, state = { learningRate = learningRate }, { }
      local loss3
      for e=1, 5 do
-         local optimfn, states = autograd.optim.sgd(df, state, params3)
+         local optimfn, configs, states = autograd.optim.sgd(df, config, state, params3)
         loss3 = 0
         for i=1,nData do
            local grads, loss = optimfn(xs:narrow(1, i, 1), ys:narrow(1, i, 1))
@@ -1583,7 +1583,7 @@ local tests = {
      local g = autograd(f, {optimize = true})
 
      -- FAILS FOR OTHER OPTIMIZERS AS WELL
-      local optimfn, states = autograd.optim.sgd(g, {learningRate=1e-2}, params)
+      local optimfn, configs, states = autograd.optim.sgd(g, {learningRate=1e-2}, {}, params)
      for i=1,3 do
 
         -- Get images in BHWD format, labels in one-hot format:
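
Usage sketch (not part of the patch): the wrapped optimizers now take a separate
config table (hyper-parameters such as learningRate) and state table (mutable
optimizer scratch space), matching the torch/optim convention, and return
per-parameter deep copies of both alongside the step function. A minimal sketch
of the new calling convention under that assumption; the model, data, and names
below (f, params, x, y) are hypothetical placeholders modeled loosely on the
updated test, not taken from the patch itself:

   local torch    = require 'torch'
   local autograd = require 'autograd'

   -- Gradient function: params is bound when the optimizer is created, so the
   -- returned optimfn is called with the remaining arguments (x, y) only,
   -- mirroring the updated test.
   local f = function(params, x, y)
      local pred = x * params.W
      return torch.sum(torch.pow(pred - y, 2))
   end
   local df = autograd(f)

   local params = { W = torch.randn(4, 1) }

   -- New signature: config and state are separate tables; one deep copy of
   -- each is kept per flattened parameter tensor.
   local config, state = { learningRate = 1e-2 }, { }
   local optimfn, configs, states = autograd.optim.sgd(df, config, state, params)

   local x, y = torch.randn(8, 4), torch.randn(8, 1)
   for epoch = 1, 5 do
      local grads, loss = optimfn(x, y)   -- one SGD step over all parameters
   end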