Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat fused_adamw #938

Merged
merged 9 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dipu/SupportedDiopiFunctions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ diopiForeachmulInpTensor
diopiForeachmulScalar
diopiForeachmulTensor
diopiForeachnormScalar
diopiFusedAdamW
diopiGather
diopiGe
diopiGeInp
Expand Down
36 changes: 36 additions & 0 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,42 @@
::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype);
interface: diopiProd(ctx, out, self_dtype_diopi, nullptr)

- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
custom_code_at_the_beginning: |
std::vector<diopiTensorHandle_t> diopiTensorHandles_self(self.size());
for(size_t i=0; i < self.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(self.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_self[i] = handle;
}
std::vector<diopiConstTensorHandle_t> diopiTensorHandles_grads(grads.size());
for(size_t i=0; i < grads.size(); ++i){
diopiTensorHandles_grads[i] = dipu::diopi_helper::toDiopiTensorHandle(grads.at(i));
}
std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avgs(exp_avgs.size());
for(size_t i=0; i < exp_avgs.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avgs.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_exp_avgs[i] = handle;
}
std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avg_sqs(exp_avg_sqs.size());
for(size_t i=0; i < exp_avg_sqs.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avg_sqs.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_exp_avg_sqs[i] = handle;
}
std::vector<diopiTensorHandle_t> diopiTensorHandles_max_exp_avg_sqs(max_exp_avg_sqs.size());
for(size_t i=0; i < max_exp_avg_sqs.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(max_exp_avg_sqs.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_max_exp_avg_sqs[i] = handle;
}
std::vector<diopiConstTensorHandle_t> diopiTensorHandles_state_steps(state_steps.size(), nullptr);
for(size_t i=0; i < state_steps.size(); ++i){
diopiTensorHandles_state_steps[i] = dipu::diopi_helper::toDiopiTensorHandle(state_steps.at(i));
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

削减重复代码,考虑:

  1. 提取函数
  2. 使用 std::transform

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

toDiopiTensorHandleVector
已经有这个函数了

interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)

- schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
custom_code_at_the_beginning: |
const auto self_dtype = at::native::to(self, dtype);
Expand Down
202 changes: 202 additions & 0 deletions dipu/tests/python/unittests/test_adamw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import torch
import numpy as np
from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn


class TestFusedAdamW(TestCase):
def setUp(self):
self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)]
self.lr_list = [0.001, 0.01, 0.001, 0.01]
self.beta1_list = [0.9, 0.9, 0.9, 0.9]
self.beta2_list = [0.999, 0.999, 0.999, 0.999]
self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
self.amsgrad_list = [False, False, True, True]
self.step_list = [2, 3, 4, 5]

def run_adamw_cpu(
self,
param,
param_grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
):
torch.optim._functional.adamw(
[param],
[param_grad],
[exp_avg],
[exp_avg_sq],
[max_exp_avg_sq],
[torch.tensor(float(step))],
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=False,
)
return param, exp_avg, exp_avg_sq, max_exp_avg_sq

def run_adamw_dipu(
self,
param,
param_grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
):
torch._fused_adamw_(
[param],
[param_grad],
[exp_avg],
[exp_avg_sq],
[max_exp_avg_sq],
[torch.tensor(float(step)).cuda()],
amsgrad=amsgrad,
lr=lr,
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=False,
grad_scale=None,
found_inf=None,
)
return param, exp_avg, exp_avg_sq, max_exp_avg_sq

def adamw_(self, dtype_):
for i in range(len(self.weight_shape_list)):
weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
weight_cpu = (
weight.cpu().to(torch.float32)
if dtype_ == torch.float16
else weight.cpu()
)
grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
grad_cpu = (
grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
)
m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
max_v_cpu = (
max_v.cpu().to(torch.float32)
if dtype_ == torch.float16
else max_v.cpu()
)

lr = self.lr_list[i]
beta1 = self.beta1_list[i]
beta2 = self.beta2_list[i]
eps = self.eps_list[i]
weight_decay = self.weight_decay_list[i]
amsgrad = self.amsgrad_list[i]
step = self.step_list[i]

w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
weight_cpu,
grad_cpu,
m_cpu,
v_cpu,
max_v_cpu,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
)
w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
weight,
grad,
m,
v,
max_v,
lr,
beta1,
beta2,
eps,
step,
weight_decay,
amsgrad,
)

self.assertTrue(
torch.allclose(
w_new.cpu(),
(
w_new_cpu.to(torch.float16)
if dtype_ == torch.float16
else w_new_cpu
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
),
torch.allclose(
m_new.cpu(),
(
m_new_cpu.to(torch.float16)
if dtype_ == torch.float16
else m_new_cpu
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
),
)
self.assertTrue(
torch.allclose(
v_new.cpu(),
(
v_new_cpu.to(torch.float16)
if dtype_ == torch.float16
else v_new_cpu
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
),
torch.allclose(
max_v_new.cpu(),
(
max_v_new_cpu.to(torch.float16)
if dtype_ == torch.float16
else max_v_new_cpu
),
atol=2e-2 if dtype_ == torch.float16 else 1e-2,
rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
equal_nan=True,
),
)

@onlyOn("CUDA")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果期望所有芯片都实现,应该用 skipon 比较好

def test_adamw_fp16_(self):
self.adamw_(torch.float16)

@onlyOn("CUDA")
def test_adamw_fp32_(self):
self.adamw_(torch.float32)


if __name__ == "__main__":
run_tests()