diff --git a/dipu/tests/python/unittests/test_adamw.py b/dipu/tests/python/unittests/test_adamw.py index d0f9b9fc9..c9cdced87 100644 --- a/dipu/tests/python/unittests/test_adamw.py +++ b/dipu/tests/python/unittests/test_adamw.py @@ -189,9 +189,11 @@ def adamw_(self, dtype_): ), ) + @onlyOn("CUDA") def test_adamw_fp16_(self): self.adamw_(torch.float16) + @onlyOn("CUDA") def test_adamw_fp32_(self): self.adamw_(torch.float32)