-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat fused_adamw #938
Merged
Merged
feat fused_adamw #938
Changes from 5 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
7333745
feat fused_adamw
zyf654321 e76a5bf
add test code
zyf654321 ffcc77f
add test code
zyf654321 cb2e125
limit only on cuda
zyf654321 1f93b60
limit only on cuda
zyf654321 9a8cf7d
Simplify code
zyf654321 a15546b
Simplify code
zyf654321 c21e538
Simplify code
zyf654321 1c97e65
Simplify code
zyf654321 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
import torch | ||
import numpy as np | ||
from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn | ||
|
||
|
||
class TestFusedAdamW(TestCase): | ||
def setUp(self): | ||
self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)] | ||
self.lr_list = [0.001, 0.01, 0.001, 0.01] | ||
self.beta1_list = [0.9, 0.9, 0.9, 0.9] | ||
self.beta2_list = [0.999, 0.999, 0.999, 0.999] | ||
self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8] | ||
self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3] | ||
self.amsgrad_list = [False, False, True, True] | ||
self.step_list = [2, 3, 4, 5] | ||
|
||
def run_adamw_cpu( | ||
self, | ||
param, | ||
param_grad, | ||
exp_avg, | ||
exp_avg_sq, | ||
max_exp_avg_sq, | ||
lr, | ||
beta1, | ||
beta2, | ||
eps, | ||
step, | ||
weight_decay, | ||
amsgrad, | ||
): | ||
torch.optim._functional.adamw( | ||
[param], | ||
[param_grad], | ||
[exp_avg], | ||
[exp_avg_sq], | ||
[max_exp_avg_sq], | ||
[torch.tensor(float(step))], | ||
amsgrad=amsgrad, | ||
beta1=beta1, | ||
beta2=beta2, | ||
lr=lr, | ||
weight_decay=weight_decay, | ||
eps=eps, | ||
maximize=False, | ||
) | ||
return param, exp_avg, exp_avg_sq, max_exp_avg_sq | ||
|
||
def run_adamw_dipu( | ||
self, | ||
param, | ||
param_grad, | ||
exp_avg, | ||
exp_avg_sq, | ||
max_exp_avg_sq, | ||
lr, | ||
beta1, | ||
beta2, | ||
eps, | ||
step, | ||
weight_decay, | ||
amsgrad, | ||
): | ||
torch._fused_adamw_( | ||
[param], | ||
[param_grad], | ||
[exp_avg], | ||
[exp_avg_sq], | ||
[max_exp_avg_sq], | ||
[torch.tensor(float(step)).cuda()], | ||
amsgrad=amsgrad, | ||
lr=lr, | ||
beta1=beta1, | ||
beta2=beta2, | ||
weight_decay=weight_decay, | ||
eps=eps, | ||
maximize=False, | ||
grad_scale=None, | ||
found_inf=None, | ||
) | ||
return param, exp_avg, exp_avg_sq, max_exp_avg_sq | ||
|
||
def adamw_(self, dtype_): | ||
for i in range(len(self.weight_shape_list)): | ||
weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda() | ||
weight_cpu = ( | ||
weight.cpu().to(torch.float32) | ||
if dtype_ == torch.float16 | ||
else weight.cpu() | ||
) | ||
grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda() | ||
grad_cpu = ( | ||
grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu() | ||
) | ||
m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda() | ||
m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu() | ||
v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda() | ||
v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu() | ||
max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda() | ||
max_v_cpu = ( | ||
max_v.cpu().to(torch.float32) | ||
if dtype_ == torch.float16 | ||
else max_v.cpu() | ||
) | ||
|
||
lr = self.lr_list[i] | ||
beta1 = self.beta1_list[i] | ||
beta2 = self.beta2_list[i] | ||
eps = self.eps_list[i] | ||
weight_decay = self.weight_decay_list[i] | ||
amsgrad = self.amsgrad_list[i] | ||
step = self.step_list[i] | ||
|
||
w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu( | ||
weight_cpu, | ||
grad_cpu, | ||
m_cpu, | ||
v_cpu, | ||
max_v_cpu, | ||
lr, | ||
beta1, | ||
beta2, | ||
eps, | ||
step, | ||
weight_decay, | ||
amsgrad, | ||
) | ||
w_new, m_new, v_new, max_v_new = self.run_adamw_dipu( | ||
weight, | ||
grad, | ||
m, | ||
v, | ||
max_v, | ||
lr, | ||
beta1, | ||
beta2, | ||
eps, | ||
step, | ||
weight_decay, | ||
amsgrad, | ||
) | ||
|
||
self.assertTrue( | ||
torch.allclose( | ||
w_new.cpu(), | ||
( | ||
w_new_cpu.to(torch.float16) | ||
if dtype_ == torch.float16 | ||
else w_new_cpu | ||
), | ||
atol=2e-2 if dtype_ == torch.float16 else 1e-2, | ||
rtol=4e-3 if dtype_ == torch.float16 else 2e-3, | ||
equal_nan=True, | ||
), | ||
torch.allclose( | ||
m_new.cpu(), | ||
( | ||
m_new_cpu.to(torch.float16) | ||
if dtype_ == torch.float16 | ||
else m_new_cpu | ||
), | ||
atol=2e-2 if dtype_ == torch.float16 else 1e-2, | ||
rtol=4e-3 if dtype_ == torch.float16 else 2e-3, | ||
equal_nan=True, | ||
), | ||
) | ||
self.assertTrue( | ||
torch.allclose( | ||
v_new.cpu(), | ||
( | ||
v_new_cpu.to(torch.float16) | ||
if dtype_ == torch.float16 | ||
else v_new_cpu | ||
), | ||
atol=2e-2 if dtype_ == torch.float16 else 1e-2, | ||
rtol=4e-3 if dtype_ == torch.float16 else 2e-3, | ||
equal_nan=True, | ||
), | ||
torch.allclose( | ||
max_v_new.cpu(), | ||
( | ||
max_v_new_cpu.to(torch.float16) | ||
if dtype_ == torch.float16 | ||
else max_v_new_cpu | ||
), | ||
atol=2e-2 if dtype_ == torch.float16 else 1e-2, | ||
rtol=4e-3 if dtype_ == torch.float16 else 2e-3, | ||
equal_nan=True, | ||
), | ||
) | ||
|
||
@onlyOn("CUDA") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果期望所有芯片都实现,应该用 skipon 比较好 |
||
def test_adamw_fp16_(self): | ||
self.adamw_(torch.float16) | ||
|
||
@onlyOn("CUDA") | ||
def test_adamw_fp32_(self): | ||
self.adamw_(torch.float32) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
削减重复代码,考虑:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
toDiopiTensorHandleVector
已经有这个函数了