Merge remote-tracking branch 'upstream/master' into clean_implement_distribution
takuseno committed Jan 9, 2020
2 parents 96cf86e + 5618029 commit 2aa8449
Showing 37 changed files with 175 additions and 7 deletions.
17 changes: 14 additions & 3 deletions doc/python/api/function.rst
@@ -59,6 +59,7 @@ Neural Network Activation
.. autofunction:: tanh
.. autofunction:: relu
.. autofunction:: softmax
.. autofunction:: log_softmax
.. autofunction:: elu
.. autofunction:: selu
.. autofunction:: crelu
@@ -80,6 +81,7 @@ Normalization
-------------

.. autofunction:: batch_normalization
.. autofunction:: fused_batch_normalization
.. autofunction:: sync_batch_normalization
.. autofunction:: mean_subtraction
.. autofunction:: clip_by_value
@@ -146,6 +148,11 @@ Logical
.. autofunction:: maximum2
.. autofunction:: minimum_scalar
.. autofunction:: maximum_scalar
.. autofunction:: isnan
.. autofunction:: isinf
.. autofunction:: reset_nan
.. autofunction:: reset_inf
.. autofunction:: where


Math
@@ -200,6 +207,8 @@ Array Manipulation
.. autofunction:: batch_inv
.. autofunction:: batch_det
.. autofunction:: assign
.. autofunction:: top_k_data
.. autofunction:: top_k_grad


Stochasticity
@@ -209,8 +218,6 @@ Stochasticity
.. autofunction:: randint
.. autofunction:: randn
.. autofunction:: dropout
.. autofunction:: top_k_data
.. autofunction:: top_k_grad
.. autofunction:: random_choice
.. autofunction:: random_crop
.. autofunction:: random_flip
@@ -253,7 +260,9 @@ Quantized Neural Network Layers
.. autofunction:: min_max_quantize
.. autofunction:: pow2_quantize
.. autofunction:: prune

.. autofunction:: inq_affine
.. autofunction:: inq_convolution


Unsupported, Special Use
------------------------
@@ -262,6 +271,7 @@ Unsupported, Special Use
.. autofunction:: unlink
.. autofunction:: sink
.. autofunction:: warp_by_flow
.. autofunction:: confusion_matrix


Image Object Detection
@@ -274,3 +284,4 @@ Validation
----------

.. autofunction:: top_n_error
.. autofunction:: binary_error
13 changes: 10 additions & 3 deletions doc/python/api/parametric_function.rst
@@ -71,6 +71,7 @@ Here is the list of parametric functions.
.. autofunction:: deconvolution
.. autofunction:: depthwise_deconvolution
.. autofunction:: batch_normalization
.. autofunction:: fused_batch_normalization
.. autofunction:: sync_batch_normalization
.. autofunction:: mean_subtraction
.. autofunction:: layer_normalization
@@ -113,9 +114,9 @@ Here is the list of parametric functions.
.. autofunction:: spectral_norm
.. autofunction:: weight_normalization
.. autofunction:: multi_head_attention
.. autoclass:: transformer
.. autoclass:: transformer_encode
.. autoclass:: transformer_decode
.. autofunction:: transformer
.. autofunction:: transformer_encode
.. autofunction:: transformer_decode

Parameter Initializer
---------------------
@@ -138,6 +139,12 @@ listed below.
.. autoclass:: UniformInitializer
:show-inheritance:

.. autoclass:: UniformIntInitializer
:show-inheritance:

.. autoclass:: RangeInitializer
:show-inheritance:

.. autoclass:: OrthogonalInitializer
:show-inheritance:

2 changes: 2 additions & 0 deletions doc/python/api/solver.rst
@@ -28,3 +28,5 @@ List of solvers
.. autofunction:: Adamax
.. autofunction:: AMSGRAD
.. autofunction:: AMSBound
.. autofunction:: AdamW
.. autofunction:: SgdW
25 changes: 25 additions & 0 deletions include/nbla/solver.hpp
@@ -165,6 +165,11 @@ class NBLA_API Solver {
*/
void weight_decay(float decay_rate);

/** Clip gradients by norm.
The norm is calculated at each variable.
*/
void clip_grad_by_norm(float norm);

/** Check if there is any inf on the gradients which were setup.
*/
bool check_inf_grad();
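Spelled out, the per-variable clipping declared in this hunk (and, reading the implementation added in include/nbla/solver/clip_grad.hpp further down in this diff) rescales a gradient g only when its L2 norm exceeds the threshold:

$$
g \leftarrow
\begin{cases}
\dfrac{\text{clip\_norm}}{\lVert g \rVert_2}\, g, & \text{if } \lVert g \rVert_2 > \text{clip\_norm}, \\
g, & \text{otherwise.}
\end{cases}
$$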
@@ -225,6 +230,15 @@ class NBLA_API Solver {
virtual void weight_decay_impl(const string &key, VariablePtr param,
float decay_rate) = 0;

/** Clip gradients by norm implementation.
@param key Key of parameter.
@param param A parameter Variable.
@param norm A value of norm.
*/
virtual void clip_grad_by_norm_impl(const string &key, VariablePtr param,
float clip_norm) = 0;

/** Check if there is any inf on the gradients which were setup.
*/
virtual bool check_inf_grad_impl(const string &key, VariablePtr param) = 0;
@@ -258,6 +272,17 @@ class NBLA_API Solver {
WEIGHT_DECAY_FUNC<T>(this->ctx_, param, decay_rate); \
}

#define NBLA_DECL_CLIP_GRAD_BY_NORM() \
virtual void clip_grad_by_norm_impl(const string &key, VariablePtr param, \
float clip_norm)

#define NBLA_DEF_CLIP_GRAD_BY_NORM(SOLVER, CLIP_GRAD_BY_NORM_FUNC) \
template <typename T> \
void SOLVER<T>::clip_grad_by_norm_impl(const string &key, VariablePtr param, \
float clip_norm) { \
CLIP_GRAD_BY_NORM_FUNC<T>(this->ctx_, param, clip_norm); \
}

#define NBLA_DECL_CHECK_INF_GRAD() \
virtual bool check_inf_grad_impl(const string &key, VariablePtr param)

1 change: 1 addition & 0 deletions include/nbla/solver/adabound.hpp
@@ -54,6 +54,7 @@ template <typename T> class NBLA_API AdaBound : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adadelta.hpp
@@ -61,6 +61,7 @@ template <typename T> class NBLA_API Adadelta : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adagrad.hpp
@@ -59,6 +59,7 @@ template <typename T> class NBLA_API Adagrad : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adam.hpp
@@ -62,6 +62,7 @@ template <typename T> class NBLA_API Adam : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adamax.hpp
@@ -63,6 +63,7 @@ template <typename T> class NBLA_API Adamax : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/adamw.hpp
@@ -65,6 +65,7 @@ template <typename T> class NBLA_API AdamW : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/amsbound.hpp
@@ -56,6 +56,7 @@ template <typename T> class NBLA_API AMSBound : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/amsgrad.hpp
@@ -70,6 +70,7 @@ template <typename T> class NBLA_API AMSGRAD : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
41 changes: 41 additions & 0 deletions include/nbla/solver/clip_grad.hpp
@@ -0,0 +1,41 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __NBLA_SOLVER_CLIP_GRAD_HPP__
#define __NBLA_SOLVER_CLIP_GRAD_HPP__

#include <nbla/context.hpp>
#include <nbla/variable.hpp>

#include <memory>

namespace nbla {

template <typename T>
void clip_grad_by_norm_cpu(const Context &ctx, const shared_ptr<Variable> param,
float clip_norm) {
Size_t size = param->size();
T *grad = param->cast_grad_and_get_pointer<T>(ctx);
T sum = 0;
for (int i = 0; i < size; ++i)
sum += grad[i] * grad[i];
// The sum > 0.0 check avoids taking sqrt of zero (and dividing by it) for an all-zero gradient.
if (sum > 0.0 && sum > clip_norm * clip_norm) {
T norm = std::sqrt(sum);
for (int i = 0; i < size; ++i)
grad[i] = clip_norm * grad[i] / norm;
}
}
}
#endif
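For readers following along in Python, here is a minimal NumPy sketch mirroring clip_grad_by_norm_cpu above; the function name and the sample values are illustrative only and not part of the library:

import numpy as np

def clip_grad_by_norm_ref(grad, clip_norm):
    # Mirror of clip_grad_by_norm_cpu: rescale only when the squared L2 norm exceeds clip_norm^2.
    sq_sum = float(np.sum(grad * grad))
    if sq_sum > 0.0 and sq_sum > clip_norm * clip_norm:
        return clip_norm * grad / np.sqrt(sq_sum)
    return grad

g = np.array([3.0, 4.0])              # L2 norm 5.0
print(clip_grad_by_norm_ref(g, 0.5))  # -> [0.3 0.4], i.e. rescaled to norm 0.5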
1 change: 1 addition & 0 deletions include/nbla/solver/lars.hpp
@@ -64,6 +64,7 @@ template <typename T> class NBLA_API Lars : public Solver {
virtual void remove_state_impl(const string &key) override;
virtual void update_impl(const string &key, VariablePtr param) override;
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/momentum.hpp
@@ -56,6 +56,7 @@ template <typename T> class NBLA_API Momentum : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/nesterov.hpp
@@ -55,6 +55,7 @@ template <typename T> class NBLA_API Nesterov : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/rmsprop.hpp
@@ -60,6 +60,7 @@ template <typename T> class NBLA_API RMSprop : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/sgd.hpp
@@ -45,6 +45,7 @@ template <typename T> class NBLA_API Sgd : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions include/nbla/solver/sgdw.hpp
@@ -52,6 +52,7 @@ template <typename T> class NBLA_API SgdW : public Solver {
virtual void remove_state_impl(const string &key);
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CLIP_GRAD_BY_NORM();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
1 change: 1 addition & 0 deletions python/src/nnabla/solver.pxd.tmpl
@@ -47,6 +47,7 @@ cdef extern from "nbla/solver.hpp" namespace "nbla":
void set_states(vector[pair[string, CSolverState]]) except +
void update() nogil except +
void weight_decay(float decay_rate) nogil except +
void clip_grad_by_norm(float clip_norm) nogil except +
cpp_bool check_inf_grad() nogil except +
cpp_bool check_nan_grad() nogil except +
cpp_bool check_inf_or_nan_grad() nogil except +
13 changes: 13 additions & 0 deletions python/src/nnabla/solver.pyx.tmpl
@@ -76,6 +76,7 @@ cdef class Solver:
solver.zero_grad() # All gradient buffer being 0
loss.backward()
solver.weight_decay(decay_rate) # Apply weight decay
solver.clip_grad_by_norm(clip_norm) # Apply clip grad by norm
solver.update() # updating parameters

Note:
@@ -353,6 +354,18 @@ cdef class Solver:
with nogil:
self.solverp.weight_decay(decay_rate)

def clip_grad_by_norm(self, float clip_norm):
"""
Clip gradients by norm.
The norm is computed for each parameter variable; a gradient whose L2 norm exceeds ``clip_norm`` is rescaled so that its norm equals ``clip_norm``.

Args:
clip_norm (float): Maximum L2 norm allowed for each parameter's gradient.
"""

with nogil:
self.solverp.clip_grad_by_norm(clip_norm)

def check_inf_grad(self, ):
"""
Check if there is any inf on the gradients which were setup.
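For orientation, here is a slightly fuller version of the training-step example from the docstring above, as a hedged sketch: the toy network, data, and hyperparameter values are placeholders, and only the order of the solver calls follows the docstring.

import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S

x = nn.Variable((8, 10))
t = nn.Variable((8, 1))
y = PF.affine(x, 1)
loss = F.mean(F.squared_error(y, t))

solver = S.Sgd(lr=0.1)
solver.set_parameters(nn.get_parameters())

x.d = np.random.randn(*x.shape)
t.d = np.random.randn(*t.shape)

solver.zero_grad()               # zero all gradient buffers
loss.forward()
loss.backward()
solver.weight_decay(1e-4)        # apply weight decay
solver.clip_grad_by_norm(0.5)    # clip each parameter's gradient to L2 norm <= 0.5
solver.update()                  # update parameters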
15 changes: 14 additions & 1 deletion python/test/solver/solver_test_utils.py
@@ -46,9 +46,14 @@ def weight_decay(self, grads, decay_rate):
param = self.params[key]
grad[...] = grad + decay_rate * param

def clip_grad_by_norm(self, grads, clip_norm):
for key, grad in iteritems(grads):
norm = np.sqrt(np.sum(grad ** 2))
grad[...] = clip_norm * grad / max(clip_norm, norm)


def solver_tester(rng, solver, ref_solver, solver_args=[], solver_kwargs={},
num_itr=5, decay=1e-4, atol=1e-6,
num_itr=5, decay=1e-4, clip_norm=0.5, atol=1e-6,
ctx=None, solver_name=None):
if ctx is None:
ctx = nn.Context()
@@ -89,6 +94,14 @@ def solver_tester(rng, solver, ref_solver, solver_args=[], solver_kwargs={},
for p, ref_p in zip(params.values(), grad_copy.values()):
assert_allclose(ref_p, p.g, atol=atol)

# Check clip grad by norm.
grad_copy = OrderedDict([(k, p.g.copy())
for k, p in iteritems(params)])
s.clip_grad_by_norm(clip_norm)
ref_s.clip_grad_by_norm(grad_copy, clip_norm)
for p, ref_p in zip(params.values(), grad_copy.values()):
assert np.allclose(ref_p, p.g, atol=atol)

# Check solver update.
for i in range(num_itr):
grads = OrderedDict([(k, rng.randn(*p.shape))
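A note on the reference formula in the test helper above: dividing by max(clip_norm, norm) is a branch-free form of the same rule, giving a factor of 1 when norm <= clip_norm and clip_norm/norm otherwise. A small self-contained check (with arbitrary values) that it agrees with the conditional form used in clip_grad.hpp:

import numpy as np

rng = np.random.RandomState(0)
clip_norm = 0.5
for _ in range(5):
    g = rng.randn(10)
    norm = np.sqrt(np.sum(g ** 2))
    conditional = clip_norm * g / norm if norm > clip_norm else g  # form used in clip_grad.hpp
    branch_free = clip_norm * g / max(clip_norm, norm)             # form used in the test helper above
    assert np.allclose(conditional, branch_free)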
13 changes: 13 additions & 0 deletions src/nbla/solver.cpp
@@ -145,6 +145,19 @@ void Solver::weight_decay(float decay_rate) {
}
}

void Solver::clip_grad_by_norm(float norm) {
if (norm == 0)
return;
for (auto &kv : params_) {
SyncedArrayPtr g = kv.second.p->grad()->array();
if (g->zeroing()) {
// The gradient is not computed. Skip.
continue;
}
clip_grad_by_norm_impl(kv.first, kv.second.p, norm);
}
}

bool Solver::check_inf_grad() {
for (auto &kv : params_) {
SyncedArrayPtr g = kv.second.p->grad()->array();
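One behavioural detail from the early return in Solver::clip_grad_by_norm above: passing 0 makes the call a no-op, so a caller can switch clipping off without branching. The snippet below is a hypothetical caller-side pattern, not code from this commit; use_clipping is a placeholder flag.

import nnabla as nn
import nnabla.parametric_functions as PF
import nnabla.solvers as S

_ = PF.affine(nn.Variable((4, 3)), 2)   # create a parameter so the solver has something to hold
solver = S.Sgd(lr=0.01)
solver.set_parameters(nn.get_parameters())

use_clipping = False                    # placeholder flag
solver.clip_grad_by_norm(0.5 if use_clipping else 0.0)  # 0 returns immediately; nothing is clipped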