-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #464 from sony/feature/20230216-lion-optimizer
Lion optimizer
- Loading branch information
Showing
3 changed files
with
122 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,3 +32,5 @@ AMSBound: | |
float: [float] | ||
Lamb: | ||
float: [float] | ||
Lion: | ||
float: [float] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// Copyright 2023 Sony Group Corporation. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef __NBLA_CUDA_SOLVER_LION_HPP__ | ||
#define __NBLA_CUDA_SOLVER_LION_HPP__ | ||
|
||
#include <nbla/cuda/cuda.hpp> | ||
#include <nbla/solver/lion.hpp> | ||
|
||
namespace nbla { | ||
|
||
template <typename T> class LionCuda : public Lion<T> { | ||
public: | ||
public: | ||
explicit LionCuda(const Context &ctx, float lr, float beta1, float beta2) | ||
: Lion<T>(ctx, lr, beta1, beta2) {} | ||
virtual ~LionCuda() {} | ||
virtual string name() { return "LionCuda"; } | ||
virtual vector<string> allowed_array_classes() { | ||
return SingletonManager::get<Cuda>()->array_classes(); | ||
} | ||
|
||
protected: | ||
virtual void update_impl(const string &key, VariablePtr param) override; | ||
NBLA_DECL_WEIGHT_DECAY(); | ||
NBLA_DECL_CLIP_GRAD_BY_NORM(); | ||
NBLA_DECL_CHECK_INF_GRAD(); | ||
NBLA_DECL_CHECK_NAN_GRAD(); | ||
NBLA_DECL_CHECK_INF_OR_NAN_GRAD(); | ||
NBLA_DECL_SCALE_GRAD(); | ||
}; | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright 2023 Sony Group Corporation. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cassert> | ||
#include <queue> | ||
|
||
#include <nbla/cuda/array/cuda_array.hpp> | ||
#include <nbla/cuda/common.hpp> | ||
#include <nbla/cuda/cuda.hpp> | ||
#include <nbla/cuda/solver/lion.hpp> | ||
|
||
#include "./clip_grad.cuh" | ||
#include "./mixed_precision_training.cuh" | ||
#include "./weight_decay.cuh" | ||
|
||
namespace nbla { | ||
|
||
namespace { | ||
template <typename T> | ||
__forceinline__ __device__ T lerp(const T a, const T b, const float t) { | ||
return a + t * (b - a); | ||
} | ||
template <typename T> __forceinline__ __device__ int sign(const T x) { | ||
return (x > T(0)) - (x < T(0)); | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void kernel_lion_update(const int num, const T *g, T *m, T *w, | ||
const float lr, const float beta1, | ||
const float beta2, const float decay_rate) { | ||
NBLA_CUDA_KERNEL_LOOP(idx, num) { | ||
auto u = sign(lerp(g[idx], m[idx], beta1)); | ||
m[idx] = lerp(g[idx], m[idx], beta2); | ||
w[idx] -= lr * (u + decay_rate * w[idx]); | ||
} | ||
} | ||
|
||
template <typename T> | ||
void LionCuda<T>::update_impl(const string &key, VariablePtr param) { | ||
typedef typename CudaType<T>::type Tc; | ||
cuda_set_device(std::stoi(this->ctx_.device_id)); | ||
dtypes dtype = get_dtype<Tc>(); | ||
|
||
auto &t = this->states_.at(key).t; | ||
t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1); | ||
|
||
Size_t size = param->size(); | ||
VariablePtr m_var = this->states_.at(key).pstate["m"]; | ||
const Tc *g = param->get_grad_pointer<Tc>(this->ctx_); | ||
Tc *w = param->cast_data_and_get_pointer<Tc>(this->ctx_); | ||
Tc *m = m_var->cast_data_and_get_pointer<Tc>(this->ctx_); | ||
|
||
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_lion_update, size, g, m, w, this->lr_, | ||
this->beta1_, this->beta2_, | ||
this->weight_decay_rate_); | ||
} | ||
|
||
NBLA_DEF_WEIGHT_DECAY(LionCuda, weight_decay_cuda); | ||
NBLA_DEF_CLIP_GRAD_BY_NORM(LionCuda, clip_grad_by_norm_cuda); | ||
NBLA_DEF_CHECK_INF_GRAD(LionCuda, check_inf_grad_cuda); | ||
NBLA_DEF_CHECK_NAN_GRAD(LionCuda, check_nan_grad_cuda); | ||
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(LionCuda, check_inf_or_nan_grad_cuda); | ||
NBLA_DEF_SCALE_GRAD(LionCuda, scale_grad_impl_cuda); | ||
} |