From 73337452802e1991cf9a69b6e134c5aa5359cc27 Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Thu, 5 Sep 2024 14:50:28 +0800
Subject: [PATCH 1/9] feat fused_adamw

---
 dipu/SupportedDiopiFunctions.txt          |  1 +
 .../diopi_functions.yaml                  | 36 +++++++++++++++++++
 2 files changed, 37 insertions(+)

diff --git a/dipu/SupportedDiopiFunctions.txt b/dipu/SupportedDiopiFunctions.txt
index 547e75955..0f6a3b0ff 100644
--- a/dipu/SupportedDiopiFunctions.txt
+++ b/dipu/SupportedDiopiFunctions.txt
@@ -101,6 +101,7 @@ diopiForeachmulInpTensor
 diopiForeachmulScalar
 diopiForeachmulTensor
 diopiForeachnormScalar
+diopiFusedAdamW
 diopiGather
 diopiGe
 diopiGeInp
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 2759f7fb6..8e3b03ff0 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1325,6 +1325,42 @@
     ::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype);
   interface: diopiProd(ctx, out, self_dtype_diopi, nullptr)
 
+- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
+  custom_code_at_the_beginning: |
+    std::vector<diopiTensorHandle_t> diopiTensorHandles_self(self.size());
+    for(size_t i=0; i < self.size(); ++i){
+        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(self.at(i));
+        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
+        diopiTensorHandles_self[i] = handle;
+    }
+    std::vector<diopiConstTensorHandle_t> diopiTensorHandles_grads(grads.size());
+    for(size_t i=0; i < grads.size(); ++i){
+        diopiTensorHandles_grads[i] = dipu::diopi_helper::toDiopiTensorHandle(grads.at(i));
+    }
+    std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avgs(exp_avgs.size());
+    for(size_t i=0; i < exp_avgs.size(); ++i){
+        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avgs.at(i));
+        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
+        diopiTensorHandles_exp_avgs[i] = handle;
+    }
+    std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avg_sqs(exp_avg_sqs.size());
+    for(size_t i=0; i < exp_avg_sqs.size(); ++i){
+        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avg_sqs.at(i));
+        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
+        diopiTensorHandles_exp_avg_sqs[i] = handle;
+    }
+    std::vector<diopiTensorHandle_t> diopiTensorHandles_max_exp_avg_sqs(max_exp_avg_sqs.size());
+    for(size_t i=0; i < max_exp_avg_sqs.size(); ++i){
+        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(max_exp_avg_sqs.at(i));
+        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
+        diopiTensorHandles_max_exp_avg_sqs[i] = handle;
+    }
+    std::vector<diopiConstTensorHandle_t> diopiTensorHandles_state_steps(state_steps.size(), nullptr);
+    for(size_t i=0; i < state_steps.size(); ++i){
+        diopiTensorHandles_state_steps[i] = dipu::diopi_helper::toDiopiTensorHandle(state_steps.at(i));
+    }
+  interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)
+
 - schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   custom_code_at_the_beginning: |
     const auto self_dtype = at::native::to(self, dtype);

From e76a5bfbe0d6d7a8d8fbeba3e49fa86d1ec5bc2f Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Fri, 20 Sep 2024 17:11:44 +0800
Subject: [PATCH 2/9] add test code

---
 dipu/tests/python/unittests/test_adamw.py | 96 +++++++++++++++++++++++
 1 file changed, 96 insertions(+)
 create mode 100644 dipu/tests/python/unittests/test_adamw.py

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
new file mode 100644
index 000000000..c76b0427e
--- /dev/null
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -0,0 +1,96 @@
+import torch
+import numpy as np
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+
+class TestFusedAdamW(TestCase):
+    def setUp(self):
+        self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)]
+        self.lr_list = [0.001, 0.01, 0.001, 0.01]
+        self.beta1_list = [0.9, 0.9, 0.9, 0.9]
+        self.beta2_list = [0.999, 0.999, 0.999, 0.999]
+        self.eps_list = [1e-8, 1e-8, 1e-8, 1e-8]
+        self.weight_decay_list = [1e-2, 1e-3, 1e-2, 1e-3]
+        self.amsgrad_list = [False, False, True, True]
+        self.step_list = [2, 3, 4, 5]
+
+    def run_adamw_cpu(self, param, param_grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, eps, step, weight_decay, amsgrad):
+        torch.optim._functional.adamw(
+        [param],
+        [param_grad],
+        [exp_avg],
+        [exp_avg_sq],
+        [max_exp_avg_sq],
+        [torch.tensor(float(step))],
+        amsgrad=amsgrad,
+        beta1=beta1,
+        beta2=beta2,
+        lr=lr,
+        weight_decay=weight_decay,
+        eps=eps,
+        maximize=False,
+        )
+        return param, exp_avg, exp_avg_sq, max_exp_avg_sq
+
+    def run_adamw_dipu(self, param, param_grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, eps, step, weight_decay, amsgrad):
+        torch._fused_adamw_(
+        [param],
+        [param_grad],
+        [exp_avg],
+        [exp_avg_sq],
+        [max_exp_avg_sq],
+        [torch.tensor(float(step)).cuda()],
+        amsgrad=amsgrad,
+        lr=lr,
+        beta1=beta1,
+        beta2=beta2,
+        weight_decay=weight_decay,
+        eps=eps,
+        maximize=False,
+        grad_scale=None,
+        found_inf=None,
+        )
+        return param, exp_avg, exp_avg_sq, max_exp_avg_sq
+
+    def adamw_(self, dtype_):
+        for i in range(len(self.weight_shape_list)):
+            weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            weight_cpu = weight.cpu().to(torch.float32) if dtype_ == torch.float16 else weight.cpu()
+            grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            grad_cpu = grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
+            m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
+            v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
+            max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
+            max_v_cpu = max_v.cpu().to(torch.float32) if dtype_ == torch.float16 else max_v.cpu()
+
+            lr = self.lr_list[i]
+            beta1 = self.beta1_list[i]
+            beta2 = self.beta2_list[i]
+            eps = self.eps_list[i]
+            weight_decay = self.weight_decay_list[i]
+            amsgrad = self.amsgrad_list[i]
+            step = self.step_list[i]
+
+            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(weight_cpu, grad_cpu, m_cpu, v_cpu, max_v_cpu, lr, beta1, beta2, eps, step, weight_decay, amsgrad)
+            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(weight, grad, m, v, max_v, lr, beta1, beta2, eps, step, weight_decay, amsgrad)
+
+            self.assertTrue(
+                torch.allclose(w_new.cpu(), w_new_cpu.to(torch.float16) if dtype_ == torch.float16 else w_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True),
+                torch.allclose(m_new.cpu(), m_new_cpu.to(torch.float16) if dtype_ == torch.float16 else m_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True),
+
+            )
+            self.assertTrue(
+                torch.allclose(v_new.cpu(), v_new_cpu.to(torch.float16) if dtype_ == torch.float16 else v_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True),
+                torch.allclose(max_v_new.cpu(), max_v_new_cpu.to(torch.float16) if dtype_ == torch.float16 else max_v_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True)
+            )
+
+    def test_adamw_fp16_(self):
+        self.adamw_(torch.float16)
+
+    def test_adamw_fp32_(self):
+        self.adamw_(torch.float32)
+
+if __name__ == "__main__":
+    run_tests()
+

From ffcc77f70f66a0132cf577d137069d47ae1bbf3a Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Fri, 20 Sep 2024 17:38:51 +0800
Subject: [PATCH 3/9] add test code

---
 dipu/tests/python/unittests/test_adamw.py | 196 +++++++++++++++++-----
 1 file changed, 150 insertions(+), 46 deletions(-)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index c76b0427e..d0f9b9fc9 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 from torch_dipu.testing._internal.common_utils import TestCase, run_tests
 
+
 class TestFusedAdamW(TestCase):
     def setUp(self):
         self.weight_shape_list = [(), (16,), (4, 8), (12, 4, 8)]
@@ -13,57 +14,95 @@ def setUp(self):
         self.amsgrad_list = [False, False, True, True]
         self.step_list = [2, 3, 4, 5]
 
-    def run_adamw_cpu(self, param, param_grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, eps, step, weight_decay, amsgrad):
+    def run_adamw_cpu(
+        self,
+        param,
+        param_grad,
+        exp_avg,
+        exp_avg_sq,
+        max_exp_avg_sq,
+        lr,
+        beta1,
+        beta2,
+        eps,
+        step,
+        weight_decay,
+        amsgrad,
+    ):
         torch.optim._functional.adamw(
-        [param],
-        [param_grad],
-        [exp_avg],
-        [exp_avg_sq],
-        [max_exp_avg_sq],
-        [torch.tensor(float(step))],
-        amsgrad=amsgrad,
-        beta1=beta1,
-        beta2=beta2,
-        lr=lr,
-        weight_decay=weight_decay,
-        eps=eps,
-        maximize=False,
-        )
+            [param],
+            [param_grad],
+            [exp_avg],
+            [exp_avg_sq],
+            [max_exp_avg_sq],
+            [torch.tensor(float(step))],
+            amsgrad=amsgrad,
+            beta1=beta1,
+            beta2=beta2,
+            lr=lr,
+            weight_decay=weight_decay,
+            eps=eps,
+            maximize=False,
+        )
         return param, exp_avg, exp_avg_sq, max_exp_avg_sq
 
-    def run_adamw_dipu(self, param, param_grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, eps, step, weight_decay, amsgrad):
+    def run_adamw_dipu(
+        self,
+        param,
+        param_grad,
+        exp_avg,
+        exp_avg_sq,
+        max_exp_avg_sq,
+        lr,
+        beta1,
+        beta2,
+        eps,
+        step,
+        weight_decay,
+        amsgrad,
+    ):
         torch._fused_adamw_(
-        [param],
-        [param_grad],
-        [exp_avg],
-        [exp_avg_sq],
-        [max_exp_avg_sq],
-        [torch.tensor(float(step)).cuda()],
-        amsgrad=amsgrad,
-        lr=lr,
-        beta1=beta1,
-        beta2=beta2,
-        weight_decay=weight_decay,
-        eps=eps,
-        maximize=False,
-        grad_scale=None,
-        found_inf=None,
-        )
+            [param],
+            [param_grad],
+            [exp_avg],
+            [exp_avg_sq],
+            [max_exp_avg_sq],
+            [torch.tensor(float(step)).cuda()],
+            amsgrad=amsgrad,
+            lr=lr,
+            beta1=beta1,
+            beta2=beta2,
+            weight_decay=weight_decay,
+            eps=eps,
+            maximize=False,
+            grad_scale=None,
+            found_inf=None,
+        )
         return param, exp_avg, exp_avg_sq, max_exp_avg_sq
 
     def adamw_(self, dtype_):
         for i in range(len(self.weight_shape_list)):
             weight = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
-            weight_cpu = weight.cpu().to(torch.float32) if dtype_ == torch.float16 else weight.cpu()
+            weight_cpu = (
+                weight.cpu().to(torch.float32)
+                if dtype_ == torch.float16
+                else weight.cpu()
+            )
             grad = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
-            grad_cpu = grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
+            grad_cpu = (
+                grad.cpu().to(torch.float32) if dtype_ == torch.float16 else grad.cpu()
+            )
             m = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
             m_cpu = m.cpu().to(torch.float32) if dtype_ == torch.float16 else m.cpu()
             v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
             v_cpu = v.cpu().to(torch.float32) if dtype_ == torch.float16 else v.cpu()
             max_v = torch.randn(self.weight_shape_list[i], dtype=dtype_).cuda()
-            max_v_cpu = max_v.cpu().to(torch.float32) if dtype_ == torch.float16 else max_v.cpu()
-
+            max_v_cpu = (
+                max_v.cpu().to(torch.float32)
+                if dtype_ == torch.float16
+                else max_v.cpu()
+            )
+
             lr = self.lr_list[i]
             beta1 = self.beta1_list[i]
             beta2 = self.beta2_list[i]
@@ -72,25 +111,90 @@ def adamw_(self, dtype_):
             amsgrad = self.amsgrad_list[i]
             step = self.step_list[i]
 
-            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(weight_cpu, grad_cpu, m_cpu, v_cpu, max_v_cpu, lr, beta1, beta2, eps, step, weight_decay, amsgrad)
-            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(weight, grad, m, v, max_v, lr, beta1, beta2, eps, step, weight_decay, amsgrad)
-
+            w_new_cpu, m_new_cpu, v_new_cpu, max_v_new_cpu = self.run_adamw_cpu(
+                weight_cpu,
+                grad_cpu,
+                m_cpu,
+                v_cpu,
+                max_v_cpu,
+                lr,
+                beta1,
+                beta2,
+                eps,
+                step,
+                weight_decay,
+                amsgrad,
+            )
+            w_new, m_new, v_new, max_v_new = self.run_adamw_dipu(
+                weight,
+                grad,
+                m,
+                v,
+                max_v,
+                lr,
+                beta1,
+                beta2,
+                eps,
+                step,
+                weight_decay,
+                amsgrad,
+            )
+
             self.assertTrue(
-                torch.allclose(w_new.cpu(), w_new_cpu.to(torch.float16) if dtype_ == torch.float16 else w_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True),
-                torch.allclose(m_new.cpu(), m_new_cpu.to(torch.float16) if dtype_ == torch.float16 else m_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True),
-
+                torch.allclose(
+                    w_new.cpu(),
+                    (
+                        w_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else w_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                ),
+                torch.allclose(
+                    m_new.cpu(),
+                    (
+                        m_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else m_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                ),
             )
             self.assertTrue(
-                torch.allclose(v_new.cpu(), v_new_cpu.to(torch.float16) if dtype_ == torch.float16 else v_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True),
-                torch.allclose(max_v_new.cpu(), max_v_new_cpu.to(torch.float16) if dtype_ == torch.float16 else max_v_new_cpu, atol=2e-2 if dtype_ == torch.float16 else 1e-2, rtol=4e-3 if dtype_ == torch.float16 else 2e-3, equal_nan=True)
+                torch.allclose(
+                    v_new.cpu(),
+                    (
+                        v_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else v_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                ),
+                torch.allclose(
+                    max_v_new.cpu(),
+                    (
+                        max_v_new_cpu.to(torch.float16)
+                        if dtype_ == torch.float16
+                        else max_v_new_cpu
+                    ),
+                    atol=2e-2 if dtype_ == torch.float16 else 1e-2,
+                    rtol=4e-3 if dtype_ == torch.float16 else 2e-3,
+                    equal_nan=True,
+                ),
             )
-
+
     def test_adamw_fp16_(self):
         self.adamw_(torch.float16)
 
     def test_adamw_fp32_(self):
         self.adamw_(torch.float32)
 
+
 if __name__ == "__main__":
     run_tests()
-

From cb2e12563bd1e9bcbe7fe43d31fea6afd811d98c Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Tue, 24 Sep 2024 16:01:55 +0800
Subject: [PATCH 4/9] limit only on cuda

---
 dipu/tests/python/unittests/test_adamw.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index d0f9b9fc9..c9cdced87 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -189,9 +189,11 @@ def adamw_(self, dtype_):
                 ),
             )
 
+    @onlyOn("CUDA")
     def test_adamw_fp16_(self):
         self.adamw_(torch.float16)
 
+    @onlyOn("CUDA")
     def test_adamw_fp32_(self):
         self.adamw_(torch.float32)
 

From 1f93b60bd72cfd6e38a9f8be25545cee2dfb973b Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Tue, 24 Sep 2024 17:37:27 +0800
Subject: [PATCH 5/9] limit only on cuda

---
 dipu/tests/python/unittests/test_adamw.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index c9cdced87..b04125653 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
 
 
 class TestFusedAdamW(TestCase):

From 9a8cf7d42e2d9d0726dde3a4f8a084598d190767 Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Thu, 10 Oct 2024 11:24:05 +0800
Subject: [PATCH 6/9] Simplify code

---
 .../diopi_functions.yaml                  | 38 +++----------------
 dipu/tests/python/unittests/test_adamw.py |  2 +-
 2 files changed, 7 insertions(+), 33 deletions(-)

diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 8e3b03ff0..1e49dd1aa 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1327,38 +1327,12 @@
 
 - schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
   custom_code_at_the_beginning: |
-    std::vector<diopiTensorHandle_t> diopiTensorHandles_self(self.size());
-    for(size_t i=0; i < self.size(); ++i){
-        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(self.at(i));
-        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-        diopiTensorHandles_self[i] = handle;
-    }
-    std::vector<diopiConstTensorHandle_t> diopiTensorHandles_grads(grads.size());
-    for(size_t i=0; i < grads.size(); ++i){
-        diopiTensorHandles_grads[i] = dipu::diopi_helper::toDiopiTensorHandle(grads.at(i));
-    }
-    std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avgs(exp_avgs.size());
-    for(size_t i=0; i < exp_avgs.size(); ++i){
-        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avgs.at(i));
-        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-        diopiTensorHandles_exp_avgs[i] = handle;
-    }
-    std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avg_sqs(exp_avg_sqs.size());
-    for(size_t i=0; i < exp_avg_sqs.size(); ++i){
-        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avg_sqs.at(i));
-        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-        diopiTensorHandles_exp_avg_sqs[i] = handle;
-    }
-    std::vector<diopiTensorHandle_t> diopiTensorHandles_max_exp_avg_sqs(max_exp_avg_sqs.size());
-    for(size_t i=0; i < max_exp_avg_sqs.size(); ++i){
-        diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(max_exp_avg_sqs.at(i));
-        diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
-        diopiTensorHandles_max_exp_avg_sqs[i] = handle;
-    }
-    std::vector<diopiConstTensorHandle_t> diopiTensorHandles_state_steps(state_steps.size(), nullptr);
-    for(size_t i=0; i < state_steps.size(); ++i){
-        diopiTensorHandles_state_steps[i] = dipu::diopi_helper::toDiopiTensorHandle(state_steps.at(i));
-    }
+    auto diopiTensorHandles_self = dipu::diopi_helper::toDiopiTensorHandleVector(self);
+    auto diopiTensorHandles_grads = dipu::diopi_helper::toDiopiConstTensorHandleVector(grads);
+    auto diopiTensorHandles_exp_avgs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avgs);
+    auto diopiTensorHandles_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(exp_avg_sqs);
+    auto diopiTensorHandles_max_exp_avg_sqs = dipu::diopi_helper::toDiopiTensorHandleVector(max_exp_avg_sqs);
+    auto diopiTensorHandles_state_steps = dipu::diopi_helper::toDiopiConstTensorHandleVector(state_steps);
   interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)
 
 - schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index b04125653..1483f28c4 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn, skipOn
 
 
 class TestFusedAdamW(TestCase):

From a15546b965ab239e117505eda984b0d62ee60dd9 Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Thu, 10 Oct 2024 11:29:19 +0800
Subject: [PATCH 7/9] Simplify code

---
 dipu/tests/python/unittests/test_adamw.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index 1483f28c4..ffc40abff 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -1,6 +1,11 @@
 import torch
 import numpy as np
-from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn, skipOn
+from torch_dipu.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+    onlyOn,
+    skipOn,
+)
 
 
 class TestFusedAdamW(TestCase):

From c21e538994f5646f519a3cb4ed5e683284303437 Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Thu, 10 Oct 2024 12:09:10 +0800
Subject: [PATCH 8/9] Simplify code

---
 dipu/tests/python/unittests/test_adamw.py | 10 ++++++++--
 .../testing/_internal/common_utils.py     | 19 +++++++++++++++++--
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index ffc40abff..ac57925a4 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -194,11 +194,17 @@ def adamw_(self, dtype_):
                 ),
             )
 
-    @onlyOn("CUDA")
+    @skipOn(
+        ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
+        "Currently, testing is only supported on CUDA",
+    )
     def test_adamw_fp16_(self):
         self.adamw_(torch.float16)
 
-    @onlyOn("CUDA")
+    @skipOn(
+        ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
+        "Currently, testing is only supported on CUDA",
+    )
     def test_adamw_fp32_(self):
         self.adamw_(torch.float32)
 
diff --git a/dipu/torch_dipu/testing/_internal/common_utils.py b/dipu/torch_dipu/testing/_internal/common_utils.py
index 4838b1416..2b1897bdb 100644
--- a/dipu/torch_dipu/testing/_internal/common_utils.py
+++ b/dipu/torch_dipu/testing/_internal/common_utils.py
@@ -65,8 +65,23 @@ def skipOnTorchVer(torchVer: str, reason: str = ""):
     return unittest.skipIf(torch_dipu.dipu.get_dipu_torch_version() == torchVer, reason)
 
 
-def skipOn(vendor: str, reason: str):
-    return unittest.skipIf(torch_dipu.dipu.vendor_type == vendor, reason)
+@overload
+def skipOn(vendor: str, reason: str): ...
+
+@overload
+def skipOn(vendor: List[str], reason: str): ...
+
+
+def skipOn(vendor, reason: str):
+    if isinstance(vendor, str):
+        vendor_list = [vendor]
+    else:
+        vendor_list = vendor
+    return unittest.skipIf(torch_dipu.dipu.vendor_type in vendor_list, reason)
+
+# def skipOn(vendor: str, reason: str):
+#     return unittest.skipIf(torch_dipu.dipu.vendor_type == vendor, reason)
+
 
 def skipIfDevcieCountLessThan(number_of_devices_required):

From 1c97e654fdef6bdacc04aa2a55d4813d09f4aaf9 Mon Sep 17 00:00:00 2001
From: zyf654321
Date: Thu, 10 Oct 2024 17:02:01 +0800
Subject: [PATCH 9/9] Simplify code

---
 dipu/tests/python/unittests/test_adamw.py         |  4 ++--
 dipu/torch_dipu/testing/_internal/common_utils.py | 12 +++++++-----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py
index ac57925a4..69ea9d495 100644
--- a/dipu/tests/python/unittests/test_adamw.py
+++ b/dipu/tests/python/unittests/test_adamw.py
@@ -196,14 +196,14 @@ def adamw_(self, dtype_):
 
     @skipOn(
         ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
-        "Currently, testing is only supported on CUDA",
+        "The fused AdamW operator has not been hooked up to DIPU for these chips yet; remove a chip from this list once its support is added",
     )
     def test_adamw_fp16_(self):
         self.adamw_(torch.float16)
 
     @skipOn(
         ["MLU", "NPU", "MUXI", "GCU", "DROPLET", "SUPA", "KLX"],
-        "Currently, testing is only supported on CUDA",
+        "The fused AdamW operator has not been hooked up to DIPU for these chips yet; remove a chip from this list once its support is added",
     )
     def test_adamw_fp32_(self):
         self.adamw_(torch.float32)
diff --git a/dipu/torch_dipu/testing/_internal/common_utils.py b/dipu/torch_dipu/testing/_internal/common_utils.py
index 2b1897bdb..7e0a5fe2b 100644
--- a/dipu/torch_dipu/testing/_internal/common_utils.py
+++ b/dipu/torch_dipu/testing/_internal/common_utils.py
@@ -68,6 +68,7 @@ def skipOnTorchVer(torchVer: str, reason: str = ""):
 @overload
 def skipOn(vendor: str, reason: str): ...
 
+
 @overload
 def skipOn(vendor: List[str], reason: str): ...
 
@@ -77,11 +78,12 @@ def skipOn(vendor, reason: str):
         vendor_list = [vendor]
     else:
         vendor_list = vendor
-    return unittest.skipIf(torch_dipu.dipu.vendor_type in vendor_list, reason)
-
-# def skipOn(vendor: str, reason: str):
-#     return unittest.skipIf(torch_dipu.dipu.vendor_type == vendor, reason)
-
+    return unittest.skipIf(
+        torch_dipu.dipu.vendor_type in vendor_list,
+        "skip on {} because {}".format(
+            vendor_list[0] if len(vendor_list) == 1 else vendor_list, reason
+        ),
+    )
 
 def skipIfDevcieCountLessThan(number_of_devices_required):
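
A note on the helpers relied on by PATCH 6: the hand-written conversion loops from PATCH 1 are replaced by dipu::diopi_helper::toDiopiTensorHandleVector and toDiopiConstTensorHandleVector, whose definitions are not part of this patch series. The sketch below only illustrates what such helpers could look like, derived from the loops removed above; the signatures, the header names, and the toDiopiTensorHandle overload it leans on are assumptions, not the actual DIPU implementation.

// Illustrative sketch only (not the DIPU source): a possible shape for the
// vector-conversion helpers referenced by the PATCH 6 wrapper code.
#include <vector>
#include <ATen/ATen.h>
#include <diopi/diopirt.h>

namespace dipu {
namespace diopi_helper {

// Assumed to exist elsewhere in DIPU (it is called by the generated code above);
// declared here only so the sketch is self-contained.
diopiConstTensorHandle_t toDiopiTensorHandle(const at::Tensor& tensor);

// Read-only inputs (grads, state_steps): const handles, no cast needed.
inline std::vector<diopiConstTensorHandle_t> toDiopiConstTensorHandleVector(at::TensorList tensors) {
  std::vector<diopiConstTensorHandle_t> handles(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    handles[i] = toDiopiTensorHandle(tensors[i]);
  }
  return handles;
}

// In-place outputs (self, exp_avgs, exp_avg_sqs, max_exp_avg_sqs): the fused
// kernel mutates these tensors, hence non-const handles, mirroring the
// const_cast in the loops that PATCH 6 removes.
inline std::vector<diopiTensorHandle_t> toDiopiTensorHandleVector(at::TensorList tensors) {
  std::vector<diopiTensorHandle_t> handles(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    handles[i] = const_cast<diopiTensorHandle_t>(toDiopiTensorHandle(tensors[i]));
  }
  return handles;
}

}  // namespace diopi_helper
}  // namespace dipu

Under these assumptions, the generated wrapper in PATCH 6 then simply forwards the .data() pointers of the resulting vectors to diopiFusedAdamW, as shown in its interface: line.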