diff --git a/init.lua b/init.lua index e0f179b..fadf9fe 100644 --- a/init.lua +++ b/init.lua @@ -117,6 +117,7 @@ torch.include('dp', 'visitor/maxnorm.lua') torch.include('dp', 'visitor/weightdecay.lua') torch.include('dp', 'visitor/learn.lua') torch.include('dp', 'visitor/momentum.lua') +torch.include('dp', 'visitor/gradclip.lua') --[[ observer ]]-- torch.include('dp', 'observer/observer.lua') diff --git a/model/layer.lua b/model/layer.lua index b8f74bb..8e74dd9 100644 --- a/model/layer.lua +++ b/model/layer.lua @@ -180,6 +180,24 @@ function Layer:maxNorm(max_out_norm, max_in_norm) end end +function Layer:gradClip(cutoff) + assert(self.backwarded, "Should call gradClip after a backward pass") + cutoff = self.mvstate.cutoff or cutoff + local params, gradParams = self:parameters() + local norm = 0 + for k,gradParam in pairs(gradParams) do + norm = norm + math.pow(gradParam:norm(),2) + end + norm = math.sqrt(norm) + if norm > cutoff then + -- rescale gradParams to obtain desired norm + for k,gradParam in pairs(gradParams) do + gradParam:mul(cutoff/norm) + end + end + return norm +end + function Layer:share(layer, ...) assert(layer.isLayer) local arg = {...} diff --git a/model/sequential.lua b/model/sequential.lua index e501f11..ca5ac8f 100644 --- a/model/sequential.lua +++ b/model/sequential.lua @@ -82,17 +82,3 @@ function Sequential:_toModule() self._models[i]:_toModule() end end - ---[[ --- experimental -function Sequential:flux(state) - local output = self.output - -- setup - for i=1,#self._models-1 do - self._models[i]:setSuccessor(self._models[i+1]) - end - return self._model[1]:flux() - self.input = output - return carry -end ---]] diff --git a/node.lua b/node.lua index 4a172f9..3955d4b 100644 --- a/node.lua +++ b/node.lua @@ -210,13 +210,3 @@ end function Node:_evaluate(carry) return self:_forward(carry) end - ---[[ --- experimental (would allow for one chained RPC call for both backward forward) -function Node:flux(carry) - local output, carry = self:forward() - local input, carry = self._successor:flux{output, carry} - local input, carry = self:backward{input, carry} - return input, carry -end ---]] diff --git a/visitor/gradclip.lua b/visitor/gradclip.lua new file mode 100644 index 0000000..3f9334b --- /dev/null +++ b/visitor/gradclip.lua @@ -0,0 +1,63 @@ +------------------------------------------------------------------------ +--[[ GradClip ]]-- +-- Ref.: A. http://goo.gl/Zxza8m +-- B. http://jmlr.org/proceedings/papers/v28/pascanu13.pdf +-- Visitor +-- Hard constraint on the upper bound of the norm of gradient with +-- respect to parameters (gradParams). Unlike ref A and B, which apply +-- the constraint on the norm of all parameters, the norm is applied +-- on the norm of each Layer's parameters. +-- Should occur before Learn in VisitorChain +------------------------------------------------------------------------ +local GradClip, parent = torch.class("dp.GradClip", "dp.Visitor") +GradClip.isGradClip = true + +function GradClip:__init(config) + config = config or {} + assert(torch.type(config) == 'table' and not config[1], + "Constructor requires key-value arguments") + local args, cutoff, name = xlua.unpack( + {config}, + 'GradClip', + 'Hard constraint on the upper bound of the norm of gradParams.', + {arg='cutoff', type='number', default=1, + help="max norm of a Layer's parameters"}, + {arg='name', type='string', default='gradclip', + help='identifies visitor in reports.'} + ) + self._cutoff = cutoff + config.include = config.include or {} + table.insert(config.include, 'hasParams') + config.exclude = config.exclude or {} + table.insert(config.exclude, 'no-gradclip') + config.name = name + parent.__init(self, config) + self.norms = {} +end + +function GradClip:_visitModel(model) + if model.gradClip then + local norm = model:gradClip(self._cutoff) + -- keep a moving average of norms + self.norms[model:id():toString()] = (self.norms[model:id():toString()] or 0)*0.8 + norm*0.2 + else + if not model.mvstate[self:id():name()].warned then + print("Warning: GradClip not implemented for model " .. + torch.typename(model) .. ". Ignoring model-visitor pair") + model.mvstate[self:id():name()].warned = true + end + end +end + +function GradClip:report() + local norms = _.values(self.norms) + if self._verbose then + print(self:id():name().." norms: ", unpack(norms)) + end + local report = { + [self:name()] = { + norms = self.norms + } + } + return report +end