Skip to content
This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Real optim state hidden because not passed in #175

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/optim/init.lua
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
local util = require 'autograd.util'

local function wrap(optimfn)
return function(fn, state, params)
local states = { }
return function(fn, config, state, params)
local configs, states = { }, { }
local flatParams = util.sortedFlatten(params)
for i = 1, #flatParams do
configs[i] = util.deepCopy(config)
states[i] = util.deepCopy(state)
end
return function(...)
Expand All @@ -15,10 +16,10 @@ local function wrap(optimfn)
local grad = flatGrads[i]
optimfn(function()
return loss, grad
end, flatParams[i], states[i])
end, flatParams[i], configs[i], states[i])
end
return table.unpack(out)
end, states
end, configs, states
end
end

Expand All @@ -28,4 +29,4 @@ for k, v in pairs(require 'optim') do
opt[k] = wrap(v)
end

return opt
return opt
6 changes: 3 additions & 3 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1530,10 +1530,10 @@ local tests = {
end
end

local state = { learningRate = learningRate }
local config, state = { learningRate = learningRate }, { }
local loss3
for e=1, 5 do
local optimfn, states = autograd.optim.sgd(df, state, params3)
local optimfn, configs, states = autograd.optim.sgd(df, config, state, params3)
loss3 = 0
for i=1,nData do
local grads, loss = optimfn(xs:narrow(1, i, 1), ys:narrow(1, i, 1))
Expand Down Expand Up @@ -1583,7 +1583,7 @@ local tests = {
local g = autograd(f, {optimize = true})

-- FAILS FOR OTHER OPTIMIZERS AS WELL
local optimfn, states = autograd.optim.sgd(g, {learningRate=1e-2}, params)
local optimfn, configs, states = autograd.optim.sgd(g, {learningRate=1e-2}, {}, params)

for i=1,3 do
-- Get images in BHWD format, labels in one-hot format:
Expand Down