From aa4e91a6dc86dec2c1407678e4d35eceb91f8d2c Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 5 Aug 2020 16:05:51 -0700 Subject: [PATCH] Fix `TestSparse.test_bmm_windows_error` when CUDA is not available (#42626) Summary: Refactor comnon pattern of (torch.cuda.version and [int(x) for x in torch.cuda.version.split(".")] >= [a, b]) into `_get_torch_cuda_version()` function Pull Request resolved: https://github.com/pytorch/pytorch/pull/42626 Reviewed By: seemethere Differential Revision: D22956149 Pulled By: malfet fbshipit-source-id: 897c55965e53b477cd20f69e8da15d90489035de --- test/test_sparse.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/test_sparse.py b/test/test_sparse.py index bdf15f3a7616f..e989566ef37b4 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -36,6 +36,9 @@ def outer(self, *args, **kwargs): inner(self, *args, **kwargs) return outer +def _get_torch_cuda_version(): + return [int(x) for x in torch.version.cuda.split(".")] if torch.version.cuda else [0, 0] + class TestSparse(TestCase): @@ -928,9 +931,7 @@ def test_shape(di, dj, dk, nnz): "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1" ) @unittest.skipIf( - TEST_CUDA and ( - not torch.version.cuda - or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]), + TEST_CUDA and _get_torch_cuda_version() < [10, 1], "bmm sparse-dense requires CUDA 10.1 or greater" ) def test_bmm(self): @@ -994,8 +995,7 @@ def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1" ) @unittest.skipIf( - (not torch.version.cuda - or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]), + _get_torch_cuda_version() < [10, 1], "bmm sparse-dense requires CUDA 10.1 or greater" ) def test_bmm_deterministic(self): @@ -1030,7 +1030,7 @@ def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): @cuda_only @unittest.skipIf( - not IS_WINDOWS or [int(x) for x in torch.version.cuda.split(".")] >= [11, 0], + not IS_WINDOWS or _get_torch_cuda_version() >= [11, 0], "this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0" ) def test_bmm_windows_error(self): @@ -1044,8 +1044,7 @@ def test_bmm_windows_error(self): @cuda_only @skipIfRocm @unittest.skipIf( - (torch.version.cuda - and [int(x) for x in torch.version.cuda.split(".")] >= [10, 1]), + _get_torch_cuda_version() >= [10, 1], "this test ensures bmm gives error if CUDA version is less than 10.1" ) def test_bmm_cuda_version_error(self):