11
11
import unittest
12
12
13
13
import torch
14
+ from hypothesis import given , settings , strategies as st , Verbosity
14
15
from tensordict import TensorDict
15
16
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
16
17
from torchrec .sparse .tensor_dict import maybe_td_to_kjt
17
- from torchrec .sparse .tests .utils import repeat_test
18
18
19
19
20
20
class TestTensorDIct (unittest .TestCase ):
21
- @repeat_test (device_str = ["cpu" , "cuda" , "meta" ])
21
+ @given (device_str = st .sampled_from (["cpu" , "cuda" , "meta" ]))
22
+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
23
+ # pyre-ignore[56]
24
+ @unittest .skipIf (
25
+ torch .cuda .device_count () <= 0 ,
26
+ "CUDA is not available" ,
27
+ )
22
28
def test_kjt_input (self , device_str : str ) -> None :
23
29
device = torch .device (device_str )
24
30
values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
@@ -30,7 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
30
36
features = maybe_td_to_kjt (kjt )
31
37
self .assertEqual (features , kjt )
32
38
33
- @repeat_test (device_str = ["cpu" , "cuda" , "meta" ])
39
+ @given (device_str = st .sampled_from (["cpu" , "cuda" , "meta" ]))
40
+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
41
+ # pyre-ignore[56]
42
+ @unittest .skipIf (
43
+ torch .cuda .device_count () <= 0 ,
44
+ "CUDA is not available" ,
45
+ )
34
46
def test_td_kjt (self , device_str : str ) -> None :
35
47
device = torch .device (device_str )
36
48
values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
0 commit comments