Skip to content

Commit

Permalink
rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 6, 2024
1 parent bb22697 commit e95cc49
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions tests/beignet/test__apply_rotation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@hypothesis.strategies.composite
def _strategy(function):
batch_size = function(
size = function(
hypothesis.strategies.integers(
min_value=1,
max_value=8,
Expand All @@ -18,27 +18,35 @@ def _strategy(function):
input = function(
hypothesis.extra.numpy.arrays(
numpy.float64,
(batch_size, 3),
(size, 3),
elements={
"allow_infinity": False,
"allow_nan": False,
"min_value": -1e3,
"max_value": +1e3,
"min_value": numpy.finfo(numpy.float32).min,
"max_value": numpy.finfo(numpy.float32).max,
},
),
)

rotation = Rotation.random(batch_size)
rotation = Rotation.random(size)

inverse = function(hypothesis.strategies.booleans())

return (
{
"input": torch.from_numpy(input),
"rotation": torch.from_numpy(rotation.as_matrix()),
"input": torch.from_numpy(
input,
),
"rotation": torch.from_numpy(
rotation.as_matrix(),
),
"inverse": inverse,
},
torch.from_numpy(rotation.apply(input, inverse)),
torch.from_numpy(
rotation.apply(
input,
inverse,
),
),
)


Expand All @@ -47,6 +55,11 @@ def test_apply_rotation_matrix(data):
parameters, expected = data

torch.testing.assert_close(
beignet.apply_rotation_matrix(**parameters),
beignet.apply_rotation_matrix(
**parameters,
),
expected,
equal_nan=True,
atol=1e-06,
rtol=1e-06,
)

0 comments on commit e95cc49

Please sign in to comment.