add adamw for deeplink ext
zhangzefeng92 committed Apr 1, 2024
1 parent c98cb5c commit 1103b24
Showing 3 changed files with 75 additions and 0 deletions.
12 changes: 12 additions & 0 deletions csrc/extensions.cpp
@@ -39,6 +39,15 @@ at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(

} // namespace

auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
              at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
              float beta1, float beta2, float epsilon, float weight_decay,
              int64_t step, bool amsgrad) {
  // the diopiAdamW function has no "maximize" parameter
  callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
            beta1, beta2, epsilon, weight_decay, step, amsgrad);
}

auto extRmsNorm(const at::Tensor& input,
const OptionalIntArray& normalized_shape,
const at::Tensor& weight, const at::Tensor& bias, double eps) {
@@ -249,6 +258,9 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
// Otherwise, do not register it here; let the Python layer handle it.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
if (&diopiAdamW != nullptr) {
m.def("adamw", &extAdamW, "deeplink ext_adamw");
}
if (&diopiRMSNorm != nullptr) { // Check if weak symbol defined
m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
m.def("rms_norm_lightllm", &extRmsNormLightllm,
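For reference, a minimal smoke-test sketch of calling the new raw binding from Python once the extension is built. The tensor shapes, device string, and hyperparameter values below are illustrative assumptions, not part of this commit; the intended entry point is the deeplink_ext/ascend_speed/adamw.py wrapper added below.

# Hypothetical smoke test for the raw ext.adamw binding (argument order follows extAdamW).
import torch
import deeplink_ext.cpp_extensions as ext

param = torch.randn(16, device="cuda")
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)
max_exp_avg_sq = torch.Tensor().cuda()  # placeholder; amsgrad is passed as False

# param, exp_avg, exp_avg_sq, max_exp_avg_sq, grad, lr, beta1, beta2, eps, weight_decay, step, amsgrad
ext.adamw(param, exp_avg, exp_avg_sq, max_exp_avg_sq, grad,
          1e-3, 0.9, 0.999, 1e-8, 0.01, 1, False)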
3 changes: 3 additions & 0 deletions deeplink_ext/ascend_speed/__init__.py
@@ -0,0 +1,3 @@
from .adamw import adamw

__all__ = ["adamw"]
60 changes: 60 additions & 0 deletions deeplink_ext/ascend_speed/adamw.py
@@ -0,0 +1,60 @@
# Copyright (c) 2024, DeepLink.

from typing import List, Optional
import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "adamw")


def adamw(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
exp_avg_sqs: List[torch.Tensor],
max_exp_avg_sqs: List[torch.Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
    norm_coeff_scale: Optional[float]
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""

    assert not maximize, "ascend diopiAdamW only supports maximize=False."
    assert not amsgrad, "ascend diopiAdamW only supports amsgrad=False."

for i, param in enumerate(params):
if norm_coeff_scale is not None:
grad = grads[i].float() * norm_coeff_scale
else:
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
        if not max_exp_avg_sqs:
            # amsgrad is disabled, so pass an empty placeholder tensor to the kernel
            max_exp_avg_sq = torch.Tensor().cuda()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
ext.adamw(
param,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
grad,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
amsgrad,
)
return params, exp_avgs, exp_avg_sqs
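
A minimal usage sketch of the wrapper above, assuming a DeepLink device where torch's "cuda" device maps to the accelerator. The sizes and hyperparameters are illustrative; in practice the state lists come from an optimizer's per-parameter state.

import torch
from deeplink_ext.ascend_speed import adamw

params = [torch.randn(8, device="cuda")]
grads = [torch.randn(8, device="cuda")]
exp_avgs = [torch.zeros(8, device="cuda")]
exp_avg_sqs = [torch.zeros(8, device="cuda")]

adamw(
    params, grads, exp_avgs, exp_avg_sqs,
    [],    # max_exp_avg_sqs: empty because amsgrad is unsupported
    [1],   # state_steps
    amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
    weight_decay=0.01, eps=1e-8, maximize=False,
    norm_coeff_scale=None,
)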
