Skip to content

Commit 9145951

Browse files
author
pytorchbot
committed
2025-01-18 nightly release (9dfdfb8)
1 parent c9cfe23 commit 9145951

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

torchrec/sparse/tests/test_tensor_dict.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,20 @@
1111
import unittest
1212

1313
import torch
14+
from hypothesis import given, settings, strategies as st, Verbosity
1415
from tensordict import TensorDict
1516
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
1617
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
17-
from torchrec.sparse.tests.utils import repeat_test
1818

1919

2020
class TestTensorDIct(unittest.TestCase):
21-
@repeat_test(device_str=["cpu", "cuda", "meta"])
21+
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
22+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
23+
# pyre-ignore[56]
24+
@unittest.skipIf(
25+
torch.cuda.device_count() <= 0,
26+
"CUDA is not available",
27+
)
2228
def test_kjt_input(self, device_str: str) -> None:
2329
device = torch.device(device_str)
2430
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
@@ -30,7 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
3036
features = maybe_td_to_kjt(kjt)
3137
self.assertEqual(features, kjt)
3238

33-
@repeat_test(device_str=["cpu", "cuda", "meta"])
39+
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
40+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
41+
# pyre-ignore[56]
42+
@unittest.skipIf(
43+
torch.cuda.device_count() <= 0,
44+
"CUDA is not available",
45+
)
3446
def test_td_kjt(self, device_str: str) -> None:
3547
device = torch.device(device_str)
3648
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)

0 commit comments

Comments
 (0)