-
Notifications
You must be signed in to change notification settings - Fork 2
/
optim-adamax-single.lua
38 lines (32 loc) · 1.15 KB
/
optim-adamax-single.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
-- Adamax optimizer (the infinity-norm variant of Adam)
-- Single-worker version only
-- Author: Minwei Feng ([email protected])
require 'optim'
--- Perform one Adamax update step on `w` (single-worker variant).
--
-- @param opfunc  function called as `opfunc(w)`; must return `fx, dfdx`
--                (the objective value and the gradient tensor).
-- @param w       parameter tensor; updated in place.
-- @param config  optional table of settings:
--                  lr      : step size (default 0.002)
--                  beta1   : decay rate of the gradient moving average (default 0.9)
--                  beta2   : decay rate of the infinity-norm estimate (default 0.999)
--                  epsilon : constant added to |dfdx| so the divisor is never
--                            zero (default 1e-38)
--                  pclient : optional object exposing `async_send_param()` and
--                            `wait()`; when absent, nothing is sent.
--                The defaults mirror those of the stock `optim.adamax`.
-- @param state   optional table that carries optimizer state between calls
--                (defaults to `config`). Fields used: adamax_t, adamax_m,
--                adamax_u, adamax_max, pversion.
-- @return `w` and a one-element table `{fx}`.
function optim.adamaxsingle(opfunc, w, config, state)
   config = config or {}
   state = state or config

   local lr = config.lr or 0.002
   local beta1 = config.beta1 or 0.9
   local beta2 = config.beta2 or 0.999
   local epsilon = config.epsilon or 1e-38
   local pc = config.pclient
   -- Lua 5.1 / LuaJIT expose `unpack` globally; 5.2+ moved it to `table.unpack`.
   local unpack = unpack or table.unpack

   -- Counter incremented once per call.
   -- NOTE(review): presumably a parameter version read elsewhere - confirm.
   state.pversion = state.pversion or 0

   local fx, dfdx = opfunc(w)

   -- Lazily create the state tensors with the same tensor type as the
   -- gradient, so that the in-place operations below do not mix tensor types
   -- when the gradient is not of the global default type.
   state.adamax_t = state.adamax_t or 0
   state.adamax_m = state.adamax_m or dfdx.new():resizeAs(dfdx):zero()
   state.adamax_u = state.adamax_u or dfdx.new():resizeAs(dfdx):zero()
   -- Scratch tensor with two slices along dimension 1, used as input to max().
   state.adamax_max = state.adamax_max
      or w.new(2, unpack(dfdx:size():totable())):zero()

   state.adamax_t = state.adamax_t + 1

   -- Moving average of the gradient: m = beta1 * m + (1 - beta1) * dfdx
   state.adamax_m:mul(beta1):add(1 - beta1, dfdx)

   -- Slice 1 holds beta2 * u, slice 2 holds |dfdx| + epsilon; the reduction
   -- along dimension 1 with max() is stored back into adamax_u.
   state.adamax_max[1]:copy(state.adamax_u):mul(beta2)
   state.adamax_max[2]:copy(dfdx):abs():add(epsilon)
   state.adamax_u:max(state.adamax_max, 1)

   -- Scale the step by 1 / (1 - beta1^t).
   local beta1_t = 1 - beta1 ^ state.adamax_t
   local lr_t = lr / beta1_t
   w:addcdiv(-lr_t, state.adamax_m, state.adamax_u)

   state.pversion = state.pversion + 1

   -- Send the updated parameters through the client (if one was supplied)
   -- and wait for the client to finish.
   if pc then
      pc:async_send_param()
      pc:wait()
   end

   return w, {fx}
end