Merge pull request #564 from takuseno/clean_implement_clip_grad_by_norm_at_solver

Implement clip_grad_by_norm at solvers
TakuyaNarihira authored Jan 8, 2020
2 parents 18b3238 + f49eb16 commit 5618029
Showing 34 changed files with 149 additions and 1 deletion.
25 changes: 25 additions & 0 deletions include/nbla/solver.hpp
@@ -165,6 +165,11 @@ class NBLA_API Solver {
*/
void weight_decay(float decay_rate);

/** Clip gradients by norm.
The L2 norm is computed separately for each variable; a gradient whose norm exceeds the given value is rescaled to that value.
*/
void clip_grad_by_norm(float norm);

/** Check if there is any inf on the gradients which were setup.
*/
bool check_inf_grad();
@@ -225,6 +230,15 @@ class NBLA_API Solver {
virtual void weight_decay_impl(const string &key, VariablePtr param,
float decay_rate) = 0;

/** Clip gradients by norm implementation.
@param key Key of parameter.
@param param A parameter Variable.
@param clip_norm Maximum gradient norm used for clipping.
*/
virtual void clip_grad_by_norm_impl(const string &key, VariablePtr param,
float clip_norm) = 0;

/** Check if there is any inf on the gradients which were setup.
*/
virtual bool check_inf_grad_impl(const string &key, VariablePtr param) = 0;
@@ -258,6 +272,17 @@ class NBLA_API Solver {
WEIGHT_DECAY_FUNC<T>(this->ctx_, param, decay_rate); \
}

#define NBLA_DECL_CLIP_GRAD_BY_NORM() \
virtual void clip_grad_by_norm_impl(const string &key, VariablePtr param, \
float clip_norm)

#define NBLA_DEF_CLIP_GRAD_BY_NORM(SOLVER, CLIP_GRAD_BY_NORM_FUNC) \
template <typename T> \
void SOLVER<T>::clip_grad_by_norm_impl(const string &key, VariablePtr param, \
float clip_norm) { \
CLIP_GRAD_BY_NORM_FUNC<T>(this->ctx_, param, clip_norm); \
}

#define NBLA_DECL_CHECK_INF_GRAD() \
virtual bool check_inf_grad_impl(const string &key, VariablePtr param)

1 change: 1 addition & 0 deletions include/nbla/solver/adabound.hpp
@@ -54,6 +54,7 @@ template <typename T> class NBLA_API AdaBound : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adadelta.hpp
@@ -61,6 +61,7 @@ template <typename T> class NBLA_API Adadelta : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adagrad.hpp
@@ -59,6 +59,7 @@ template <typename T> class NBLA_API Adagrad : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adam.hpp
@@ -62,6 +62,7 @@ template <typename T> class NBLA_API Adam : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adamax.hpp
@@ -63,6 +63,7 @@ template <typename T> class NBLA_API Adamax : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adamw.hpp
@@ -65,6 +65,7 @@ template <typename T> class NBLA_API AdamW : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/amsbound.hpp
@@ -56,6 +56,7 @@ template <typename T> class NBLA_API AMSBound : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/amsgrad.hpp
@@ -70,6 +70,7 @@ template <typename T> class NBLA_API AMSGRAD : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
41 changes: 41 additions & 0 deletions include/nbla/solver/clip_grad.hpp
@@ -0,0 +1,41 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __NBLA_SOLVER_CLIP_GRAD_HPP__
#define __NBLA_SOLVER_CLIP_GRAD_HPP__

#include <nbla/context.hpp>
#include <nbla/variable.hpp>

#include <memory>

namespace nbla {

template <typename T>
void clip_grad_by_norm_cpu(const Context &ctx, const shared_ptr<Variable> param,
float clip_norm) {
Size_t size = param->size();
T *grad = param->cast_grad_and_get_pointer<T>(ctx);
T sum = 0;
for (int i = 0; i < size; ++i)
sum += grad[i] * grad[i];
// The sum > 0.0 check avoids sqrt(0) and division by a zero norm.
if (sum > 0.0 && sum > clip_norm * clip_norm) {
T norm = std::sqrt(sum);
for (int i = 0; i < size; ++i)
grad[i] = clip_norm * grad[i] / norm;
}
}
}
#endif
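
The kernel above clips each parameter's gradient independently: it accumulates the sum of squared gradient values and, only when that sum exceeds clip_norm squared, rescales the gradient so its L2 norm equals clip_norm. A minimal NumPy sketch of the same rule, for illustration only (clip_grad_by_norm_ref is not a library function):

import numpy as np

def clip_grad_by_norm_ref(grad, clip_norm):
    # Per-variable L2 norm, mirroring the sum-of-squares loop in clip_grad_by_norm_cpu.
    norm = np.sqrt(np.sum(grad ** 2))
    if norm > 0 and norm > clip_norm:
        # Rescale so the clipped gradient has norm exactly clip_norm.
        return clip_norm * grad / norm
    return grad

g = np.array([3.0, 4.0])                 # L2 norm is 5.0
print(clip_grad_by_norm_ref(g, 1.0))     # [0.6 0.8], norm 1.0
print(clip_grad_by_norm_ref(g, 10.0))    # unchanged, norm already below 10.0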
1 change: 1 addition & 0 deletions include/nbla/solver/lars.hpp
@@ -64,6 +64,7 @@ template <typename T> class NBLA_API Lars : public Solver {
virtual void remove_state_impl(const string &key) override;
virtual void update_impl(const string &key, VariablePtr param) override;
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/momentum.hpp
@@ -56,6 +56,7 @@ template <typename T> class NBLA_API Momentum : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/nesterov.hpp
@@ -55,6 +55,7 @@ template <typename T> class NBLA_API Nesterov : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/rmsprop.hpp
@@ -60,6 +60,7 @@ template <typename T> class NBLA_API RMSprop : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/sgd.hpp
@@ -45,6 +45,7 @@ template <typename T> class NBLA_API Sgd : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/sgdw.hpp
@@ -52,6 +52,7 @@ template <typename T> class NBLA_API SgdW : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions python/src/nnabla/solver.pxd.tmpl
@@ -47,6 +47,7 @@ cdef extern from "nbla/solver.hpp" namespace "nbla":
void set_states(vector[pair[string, CSolverState]]) except +
void update() nogil except +
void weight_decay(float decay_rate) nogil except +
void clip_grad_by_norm(float clip_norm) nogil except +
cpp_bool check_inf_grad() nogil except +
cpp_bool check_nan_grad() nogil except +
cpp_bool check_inf_or_nan_grad() nogil except +
13 changes: 13 additions & 0 deletions python/src/nnabla/solver.pyx.tmpl
@@ -76,6 +76,7 @@ cdef class Solver:
solver.zero_grad() # All gradient buffer being 0
loss.backward()
solver.weight_decay(decay_rate) # Apply weight decay
solver.clip_grad_by_norm(clip_norm) # Apply clip grad by norm
solver.update() # updating parameters

Note:
@@ -353,6 +354,18 @@ cdef class Solver:
with nogil:
self.solverp.weight_decay(decay_rate)

def clip_grad_by_norm(self, float clip_norm):
"""
Clip gradients by norm.
The L2 norm is computed separately for each parameter's gradient; a gradient whose norm exceeds clip_norm is rescaled so that its norm equals clip_norm.

Args:
clip_norm (float): Maximum allowed L2 norm of each parameter's gradient.
"""

with nogil:
self.solverp.clip_grad_by_norm(clip_norm)

def check_inf_grad(self, ):
"""
Check if there is any inf on the gradients which were setup.
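For context, the binding added above slots into a standard nnabla training step between backward() and update(), as in the docstring example near the top of this file. A hedged end-to-end sketch (the tiny affine model and the constants are illustrative, not part of this commit):

import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S

x = nn.Variable((8, 10))
t = nn.Variable((8, 1))
y = PF.affine(x, 1)                       # toy model for illustration
loss = F.mean(F.squared_error(y, t))

solver = S.Sgd(lr=0.1)
solver.set_parameters(nn.get_parameters())

x.d = np.random.randn(*x.shape)
t.d = np.random.randn(*t.shape)

solver.zero_grad()                        # zero all gradient buffers
loss.forward()
loss.backward()
solver.weight_decay(1e-4)                 # optional weight decay
solver.clip_grad_by_norm(0.5)             # clip each parameter's gradient to norm <= 0.5
solver.update()                           # apply the parameter update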
15 changes: 14 additions & 1 deletion python/test/solver/solver_test_utils.py
@@ -46,9 +46,14 @@ def weight_decay(self, grads, decay_rate):
param = self.params[key]
grad[...] = grad + decay_rate * param

def clip_grad_by_norm(self, grads, clip_norm):
for key, grad in iteritems(grads):
norm = np.sqrt(np.sum(grad ** 2))
grad[...] = clip_norm * grad / max(clip_norm, norm)


def solver_tester(rng, solver, ref_solver, solver_args=[], solver_kwargs={},
num_itr=5, decay=1e-4, atol=1e-6,
num_itr=5, decay=1e-4, clip_norm=0.5, atol=1e-6,
ctx=None, solver_name=None):
if ctx is None:
ctx = nn.Context()
@@ -89,6 +94,14 @@ def solver_tester(rng, solver, ref_solver, solver_args=[], solver_kwargs={},
for p, ref_p in zip(params.values(), grad_copy.values()):
assert_allclose(ref_p, p.g, atol=atol)

# Check clip grad by norm.
grad_copy = OrderedDict([(k, p.g.copy())
for k, p in iteritems(params)])
s.clip_grad_by_norm(clip_norm)
ref_s.clip_grad_by_norm(grad_copy, clip_norm)
for p, ref_p in zip(params.values(), grad_copy.values()):
assert np.allclose(ref_p, p.g, atol=atol)

# Check solver update.
for i in range(num_itr):
grads = OrderedDict([(k, rng.randn(*p.shape))
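The NumPy reference clip_grad_by_norm above expresses the clip as a single formula, clip_norm * grad / max(clip_norm, norm). For positive clip_norm this matches the branchy C++ kernel: when norm <= clip_norm the divisor is clip_norm and the gradient is unchanged, otherwise the divisor is norm and the gradient is rescaled to norm clip_norm. A quick self-contained check of that equivalence (illustrative only):

import numpy as np

def clip_branchy(grad, clip_norm):
    # Mirrors the condition in clip_grad_by_norm_cpu.
    norm = np.sqrt(np.sum(grad ** 2))
    return clip_norm * grad / norm if norm > clip_norm else grad

def clip_formula(grad, clip_norm):
    # Mirrors the reference clip_grad_by_norm added above.
    norm = np.sqrt(np.sum(grad ** 2))
    return clip_norm * grad / max(clip_norm, norm)

rng = np.random.RandomState(0)
for _ in range(100):
    g = rng.randn(20)
    assert np.allclose(clip_branchy(g, 0.5), clip_formula(g, 0.5))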
13 changes: 13 additions & 0 deletions src/nbla/solver.cpp
@@ -145,6 +145,19 @@ void Solver::weight_decay(float decay_rate) {
}
}

void Solver::clip_grad_by_norm(float norm) {
if (norm == 0)
return;
for (auto &kv : params_) {
SyncedArrayPtr g = kv.second.p->grad()->array();
if (g->zeroing()) {
// The gradient is not computed. Skip.
continue;
}
clip_grad_by_norm_impl(kv.first, kv.second.p, norm);
}
}

bool Solver::check_inf_grad() {
for (auto &kv : params_) {
SyncedArrayPtr g = kv.second.p->grad()->array();
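The dispatcher added to solver.cpp treats a clip value of 0 as "do not clip" and skips parameters whose gradient buffers are still flagged as zeroed (backward never wrote to them). At the Python level this means clip_grad_by_norm(0.0) should leave every gradient untouched; a hedged behavioral sketch (toy model and constants are illustrative):

import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S

x = nn.Variable((4, 3))
y = PF.affine(x, 2)
loss = F.mean(y)

solver = S.Sgd(lr=0.1)
solver.set_parameters(nn.get_parameters())

x.d = np.random.randn(*x.shape)
solver.zero_grad()
loss.forward()
loss.backward()

before = {k: p.g.copy() for k, p in nn.get_parameters().items()}
solver.clip_grad_by_norm(0.0)             # 0 disables clipping in Solver::clip_grad_by_norm
for k, p in nn.get_parameters().items():
    assert np.allclose(before[k], p.g)    # gradients are untouched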
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/adabound.cpp
@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <nbla/solver/adabound.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -78,6 +79,7 @@ void AdaBound<T>::update_impl(const string &key, VariablePtr param) {
}

NBLA_DEF_WEIGHT_DECAY(AdaBound, weight_decay_cpu);
NBLA_DEF_CLIP_GRAD_BY_NORM(AdaBound, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(AdaBound, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(AdaBound, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(AdaBound, check_inf_or_nan_grad_cpu);
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/adadelta.cpp
@@ -15,6 +15,7 @@
#include <algorithm>
#include <cmath>
#include <nbla/solver/adadelta.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -69,6 +70,7 @@ void Adadelta<T>::update_impl(const string &key, VariablePtr param) {
}

NBLA_DEF_WEIGHT_DECAY(Adadelta, weight_decay_cpu);
NBLA_DEF_CLIP_GRAD_BY_NORM(Adadelta, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(Adadelta, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(Adadelta, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(Adadelta, check_inf_or_nan_grad_cpu);
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/adagrad.cpp
@@ -15,6 +15,7 @@
#include <algorithm>
#include <cmath>
#include <nbla/solver/adagrad.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -59,6 +60,7 @@ void Adagrad<T>::update_impl(const string &key, VariablePtr param) {
}

NBLA_DEF_WEIGHT_DECAY(Adagrad, weight_decay_cpu);
NBLA_DEF_CLIP_GRAD_BY_NORM(Adagrad, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(Adagrad, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(Adagrad, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(Adagrad, check_inf_or_nan_grad_cpu);
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/adam.cpp
@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <nbla/solver/adam.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -72,6 +73,7 @@ void Adam<T>::update_impl(const string &key, VariablePtr param) {
}

NBLA_DEF_WEIGHT_DECAY(Adam, weight_decay_cpu);
NBLA_DEF_CLIP_GRAD_BY_NORM(Adam, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(Adam, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(Adam, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(Adam, check_inf_or_nan_grad_cpu);
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/adamax.cpp
@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <nbla/solver/adamax.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -71,6 +72,7 @@ void Adamax<T>::update_impl(const string &key, VariablePtr param) {
}

NBLA_DEF_WEIGHT_DECAY(Adamax, weight_decay_cpu);
NBLA_DEF_CLIP_GRAD_BY_NORM(Adamax, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(Adamax, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(Adamax, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(Adamax, check_inf_or_nan_grad_cpu);
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/adamw.cpp
@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <nbla/solver/adamw.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -82,6 +83,7 @@ void AdamW<T>::weight_decay_impl(const string &key, VariablePtr param,
weight_decay_cpu<T>(this->ctx_, param, decay_rate);
}

NBLA_DEF_CLIP_GRAD_BY_NORM(AdamW, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(AdamW, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(AdamW, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(AdamW, check_inf_or_nan_grad_cpu);
2 changes: 2 additions & 0 deletions src/nbla/solver/generic/amsbound.cpp
@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <nbla/solver/amsbound.hpp>
#include <nbla/solver/clip_grad.hpp>
#include <nbla/solver/mixed_precision_training.hpp>
#include <nbla/solver/weight_decay.hpp>

@@ -87,6 +88,7 @@ void AMSBound<T>::update_impl(const string &key, VariablePtr param) {
}

NBLA_DEF_WEIGHT_DECAY(AMSBound, weight_decay_cpu);
NBLA_DEF_CLIP_GRAD_BY_NORM(AMSBound, clip_grad_by_norm_cpu);
NBLA_DEF_CHECK_INF_GRAD(AMSBound, check_inf_grad_cpu);
NBLA_DEF_CHECK_NAN_GRAD(AMSBound, check_nan_grad_cpu);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(AMSBound, check_inf_or_nan_grad_cpu);