Skip to content

Commit

Permalink
Fix TestSparse.test_bmm_windows_error when CUDA is not available (p…
Browse files Browse the repository at this point in the history
…ytorch#42626)

Summary:
Refactor comnon pattern of (torch.cuda.version and [int(x) for x in torch.cuda.version.split(".")] >= [a, b]) into `_get_torch_cuda_version()` function

Pull Request resolved: pytorch#42626

Reviewed By: seemethere

Differential Revision: D22956149

Pulled By: malfet

fbshipit-source-id: 897c55965e53b477cd20f69e8da15d90489035de
  • Loading branch information
malfet authored and facebook-github-bot committed Aug 5, 2020
1 parent 5023995 commit aa4e91a
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def outer(self, *args, **kwargs):
inner(self, *args, **kwargs)
return outer

def _get_torch_cuda_version():
return [int(x) for x in torch.version.cuda.split(".")] if torch.version.cuda else [0, 0]


class TestSparse(TestCase):

Expand Down Expand Up @@ -928,9 +931,7 @@ def test_shape(di, dj, dk, nnz):
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@unittest.skipIf(
TEST_CUDA and (
not torch.version.cuda
or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]),
TEST_CUDA and _get_torch_cuda_version() < [10, 1],
"bmm sparse-dense requires CUDA 10.1 or greater"
)
def test_bmm(self):
Expand Down Expand Up @@ -994,8 +995,7 @@ def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@unittest.skipIf(
(not torch.version.cuda
or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]),
_get_torch_cuda_version() < [10, 1],
"bmm sparse-dense requires CUDA 10.1 or greater"
)
def test_bmm_deterministic(self):
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):

@cuda_only
@unittest.skipIf(
not IS_WINDOWS or [int(x) for x in torch.version.cuda.split(".")] >= [11, 0],
not IS_WINDOWS or _get_torch_cuda_version() >= [11, 0],
"this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0"
)
def test_bmm_windows_error(self):
Expand All @@ -1044,8 +1044,7 @@ def test_bmm_windows_error(self):
@cuda_only
@skipIfRocm
@unittest.skipIf(
(torch.version.cuda
and [int(x) for x in torch.version.cuda.split(".")] >= [10, 1]),
_get_torch_cuda_version() >= [10, 1],
"this test ensures bmm gives error if CUDA version is less than 10.1"
)
def test_bmm_cuda_version_error(self):
Expand Down

0 comments on commit aa4e91a

Please sign in to comment.