diff --git a/tests/beignet/test__apply_rotation_matrix.py b/tests/beignet/test__apply_rotation_matrix.py index 9336bd6c8f..7f7f38477a 100644 --- a/tests/beignet/test__apply_rotation_matrix.py +++ b/tests/beignet/test__apply_rotation_matrix.py @@ -8,7 +8,7 @@ @hypothesis.strategies.composite def _strategy(function): - batch_size = function( + size = function( hypothesis.strategies.integers( min_value=1, max_value=8, @@ -18,27 +18,35 @@ def _strategy(function): input = function( hypothesis.extra.numpy.arrays( numpy.float64, - (batch_size, 3), + (size, 3), elements={ "allow_infinity": False, - "allow_nan": False, - "min_value": -1e3, - "max_value": +1e3, + "min_value": numpy.finfo(numpy.float32).min, + "max_value": numpy.finfo(numpy.float32).max, }, ), ) - rotation = Rotation.random(batch_size) + rotation = Rotation.random(size) inverse = function(hypothesis.strategies.booleans()) return ( { - "input": torch.from_numpy(input), - "rotation": torch.from_numpy(rotation.as_matrix()), + "input": torch.from_numpy( + input, + ), + "rotation": torch.from_numpy( + rotation.as_matrix(), + ), "inverse": inverse, }, - torch.from_numpy(rotation.apply(input, inverse)), + torch.from_numpy( + rotation.apply( + input, + inverse, + ), + ), ) @@ -47,6 +55,11 @@ def test_apply_rotation_matrix(data): parameters, expected = data torch.testing.assert_close( - beignet.apply_rotation_matrix(**parameters), + beignet.apply_rotation_matrix( + **parameters, + ), expected, + equal_nan=True, + atol=1e-06, + rtol=1e-06, )