Commit

add test code
zyf654321 committed Sep 20, 2024
1 parent e76a5bf commit ffcc77f
Showing 1 changed file with 150 additions and 46 deletions.
dipu/tests/python/unittests/test_adamw.py
import torch
import numpy as np
from torch_dipu.testing._internal.common_utils import TestCase, run_tests


class TestFusedAdamW(TestCase):
    def setUp(self):
        self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)]
        # The remaining hyper-parameter lists (lr_list, beta1_list, beta2_list,
        # and presumably eps_list / weight_decay_list) are defined here but
        # collapsed in the diff view.
        self.amsgrad_list = [False, False, True, True]
        self.step_list = [2, 3, 4, 5]

    def run_adamw_cpu(
        self,
        param,
        param_grad,
        exp_avg,
        exp_avg_sq,
        max_exp_avg_sq,
        lr,
        beta1,
        beta2,
        eps,
        step,
        weight_decay,
        amsgrad,
    ):
        # Reference path: one AdamW step with PyTorch's functional CPU
        # implementation; the state tensors are updated in place.
        torch.optim._functional.adamw(
            [param],
            [param_grad],
            [exp_avg],
            [exp_avg_sq],
            [max_exp_avg_sq],
            [torch.tensor(float(step))],
            amsgrad=amsgrad,
            beta1=beta1,
            beta2=beta2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
        )
        return param, exp_avg, exp_avg_sq, max_exp_avg_sq

    def run_adamw_dipu(
        self,
        param,
        param_grad,
        exp_avg,
        exp_avg_sq,
        max_exp_avg_sq,
        lr,
        beta1,
        beta2,
        eps,
        step,
        weight_decay,
        amsgrad,
    ):
        # Device path: one AdamW step with the fused torch._fused_adamw_ kernel
        # on the DIPU/CUDA device; the state tensors are updated in place.
        torch._fused_adamw_(
            [param],
            [param_grad],
            [exp_avg],
            [exp_avg_sq],
            [max_exp_avg_sq],
            [torch.tensor(float(step)).cuda()],
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        return param, exp_avg, exp_avg_sq, max_exp_avg_sq

    def adamw_(self, dtype_):
        # Run the fused device AdamW and the CPU reference on identical inputs
        # for every shape / hyper-parameter combination prepared in setUp.
        for i in range(len(self.weight_shape_list)):
            weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            weight_cpu = (
                weight.cpu().to(torch.float32)
                if dtype_ == torch.float16
                else weight.cpu()
            )
            grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            grad_cpu = (
                grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
            )
            m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
            v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
            max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
            max_v_cpu = (
                max_v.cpu().to(torch.float32)
                if dtype_ == torch.float16
                else max_v.cpu()
            )

            lr = self.lr_list[i]
            beta1 = self.beta1_list[i]
            beta2 = self.beta2_list[i]
            # eps and weight_decay are read from their lists in two lines that
            # the diff view collapses; reconstructed here from their use below.
            eps = self.eps_list[i]
            weight_decay = self.weight_decay_list[i]
            amsgrad = self.amsgrad_list[i]
            step = self.step_list[i]

            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
                weight_cpu,
                grad_cpu,
                m_cpu,
                v_cpu,
                max_v_cpu,
                lr,
                beta1,
                beta2,
                eps,
                step,
                weight_decay,
                amsgrad,
            )
            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
                weight,
                grad,
                m,
                v,
                max_v,
                lr,
                beta1,
                beta2,
                eps,
                step,
                weight_decay,
                amsgrad,
            )

            # Both comparisons in each assertion must hold; tolerances are
            # loosened for float16, whose reference is computed in float32.
            self.assertTrue(
                torch.allclose(
                    w_new.cpu(),
                    (
                        w_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else w_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
                and torch.allclose(
                    m_new.cpu(),
                    (
                        m_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else m_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
            )
            self.assertTrue(
                torch.allclose(
                    v_new.cpu(),
                    (
                        v_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else v_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
                and torch.allclose(
                    max_v_new.cpu(),
                    (
                        max_v_new_cpu.to(torch.float16)
                        if dtype_ == torch.float16
                        else max_v_new_cpu
                    ),
                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
                    equal_nan=True,
                )
            )

    def test_adamw_fp16_(self):
        self.adamw_(torch.float16)

    def test_adamw_fp32_(self):
        self.adamw_(torch.float32)


if __name__ == "__main__":
    run_tests()
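For reference, both code paths exercised by this test implement the standard AdamW update with optional AMSGrad. The sketch below restates that rule in plain PyTorch to make clear what the fused kernel and the CPU reference are expected to agree on; it is illustrative only, not part of the committed file, and the function name adamw_step_reference is hypothetical (it reuses the torch import at the top of the file).

def adamw_step_reference(
    param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq,
    lr, beta1, beta2, eps, step, weight_decay, amsgrad
):
    # Decoupled weight decay: applied directly to the parameter.
    param = param * (1 - lr * weight_decay)

    # Exponential moving averages of the gradient and its square.
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad

    # Bias corrections; the functional/fused ops manage the step counter
    # themselves, so `step` here is the step at which the correction applies.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    if amsgrad:
        # AMSGrad keeps the running element-wise maximum of the second moment.
        max_exp_avg_sq = torch.maximum(max_exp_avg_sq, exp_avg_sq)
        denom = (max_exp_avg_sq / bias_correction2).sqrt() + eps
    else:
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

    param = param - lr * (exp_avg / bias_correction1) / denom
    return param, exp_avg, exp_avg_sq, max_exp_avg_sq

Because the float16 case is checked against a float32 CPU reference, the test compares after casting the reference back to float16 and uses looser atol/rtol than the float32 case.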
