From cda97b9d4c6df3d443f984ffcd613817070e63a9 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Tue, 4 Jun 2024 16:21:34 -0400 Subject: [PATCH] tests --- .../_random_rotation_matrix_dataset.py | 4 +- .../_random_rotation_vector_dataset.py | 4 +- .../test__random_euler_angle_dataset.py | 58 ++++++++++++++++++- .../test__random_quaternion_dataset.py | 38 +++++++++++- .../test__random_rotation_matrix_dataset.py | 35 ++++++++++- .../test__random_rotation_vector_dataset.py | 38 +++++++++++- 6 files changed, 165 insertions(+), 12 deletions(-) diff --git a/src/beignet/datasets/_random_rotation_matrix_dataset.py b/src/beignet/datasets/_random_rotation_matrix_dataset.py index 559ee76070..a033892e4c 100644 --- a/src/beignet/datasets/_random_rotation_matrix_dataset.py +++ b/src/beignet/datasets/_random_rotation_matrix_dataset.py @@ -1,13 +1,13 @@ from typing import Callable, Generator import torch -from torch.utils.data import Dataset import beignet +from beignet.datasets.__random_rotation_dataset import RandomRotationDataset from beignet.transforms import Transform -class RandomRotationMatrixDataset(Dataset): +class RandomRotationMatrixDataset(RandomRotationDataset): def __init__( self, size: int, diff --git a/src/beignet/datasets/_random_rotation_vector_dataset.py b/src/beignet/datasets/_random_rotation_vector_dataset.py index 176b5a43de..d1e0297140 100644 --- a/src/beignet/datasets/_random_rotation_vector_dataset.py +++ b/src/beignet/datasets/_random_rotation_vector_dataset.py @@ -1,13 +1,13 @@ from typing import Callable, Generator import torch -from torch.utils.data import Dataset import beignet +from beignet.datasets.__random_rotation_dataset import RandomRotationDataset from beignet.transforms import Transform -class RandomRotationVectorDataset(Dataset): +class RandomRotationVectorDataset(RandomRotationDataset): def __init__( self, size: int, diff --git a/tests/beignet/datasets/test__random_euler_angle_dataset.py b/tests/beignet/datasets/test__random_euler_angle_dataset.py index f94e960867..8cf996ef2e 100644 --- a/tests/beignet/datasets/test__random_euler_angle_dataset.py +++ b/tests/beignet/datasets/test__random_euler_angle_dataset.py @@ -1,3 +1,57 @@ +import beignet +import hypothesis.strategies +from beignet.datasets import RandomEulerAngleDataset + + +@hypothesis.strategies.composite +def _strategy(function): + size = function( + hypothesis.strategies.integers( + min_value=1, + max_value=8, + ), + ) + + axes = function( + hypothesis.strategies.sampled_from( + [ + "xyz", + "xzy", + "yxz", + "yzx", + "zxy", + "zyx", + "XYZ", + "XZY", + "YXZ", + "YZX", + "ZXY", + "ZYX", + ] + ), + ) + + degrees = function(hypothesis.strategies.booleans()) + + return ( + { + "size": size, + "axes": axes, + "degrees": degrees, + }, + beignet.random_euler_angle(size, axes=axes, degrees=degrees), + ) + + class TestRandomEulerAngleDataset: - def test___init__(self): - assert True + @hypothesis.given(_strategy()) + def test___init__(self, data): + parameters, output = data + + dataset = RandomEulerAngleDataset(**parameters) + + assert dataset.data.shape == output.shape + + assert dataset.data.dtype == output.dtype + + assert dataset.data.layout == output.layout diff --git a/tests/beignet/datasets/test__random_quaternion_dataset.py b/tests/beignet/datasets/test__random_quaternion_dataset.py index b4bf044d28..5bb9bf3731 100644 --- a/tests/beignet/datasets/test__random_quaternion_dataset.py +++ b/tests/beignet/datasets/test__random_quaternion_dataset.py @@ -1,3 +1,37 @@ +import beignet +import hypothesis.strategies +from beignet.datasets import RandomQuaternionDataset + + +@hypothesis.strategies.composite +def _strategy(function): + size = function( + hypothesis.strategies.integers( + min_value=1, + max_value=8, + ), + ) + + canonical = function(hypothesis.strategies.booleans()) + + return ( + { + "size": size, + "canonical": canonical, + }, + beignet.random_quaternion(size, canonical=canonical), + ) + + class TestRandomQuaternionDataset: - def test___init__(self): - assert True + @hypothesis.given(_strategy()) + def test___init__(self, data): + parameters, output = data + + dataset = RandomQuaternionDataset(**parameters) + + assert dataset.data.shape == output.shape + + assert dataset.data.dtype == output.dtype + + assert dataset.data.layout == output.layout diff --git a/tests/beignet/datasets/test__random_rotation_matrix_dataset.py b/tests/beignet/datasets/test__random_rotation_matrix_dataset.py index 7a46838fab..0c4ec70303 100644 --- a/tests/beignet/datasets/test__random_rotation_matrix_dataset.py +++ b/tests/beignet/datasets/test__random_rotation_matrix_dataset.py @@ -1,3 +1,34 @@ +import beignet +import hypothesis.strategies +from beignet.datasets import RandomRotationMatrixDataset + + +@hypothesis.strategies.composite +def _strategy(function): + size = function( + hypothesis.strategies.integers( + min_value=1, + max_value=8, + ), + ) + + return ( + { + "size": size, + }, + beignet.random_rotation_matrix(size), + ) + + class TestRandomRotationMatrixDataset: - def test___init__(self): - assert True + @hypothesis.given(_strategy()) + def test___init__(self, data): + parameters, output = data + + dataset = RandomRotationMatrixDataset(**parameters) + + assert dataset.data.shape == output.shape + + assert dataset.data.dtype == output.dtype + + assert dataset.data.layout == output.layout diff --git a/tests/beignet/datasets/test__random_rotation_vector_dataset.py b/tests/beignet/datasets/test__random_rotation_vector_dataset.py index 5ced5e380d..26d50ab2e7 100644 --- a/tests/beignet/datasets/test__random_rotation_vector_dataset.py +++ b/tests/beignet/datasets/test__random_rotation_vector_dataset.py @@ -1,3 +1,37 @@ +import beignet +import hypothesis.strategies +from beignet.datasets import RandomRotationVectorDataset + + +@hypothesis.strategies.composite +def _strategy(function): + size = function( + hypothesis.strategies.integers( + min_value=1, + max_value=8, + ), + ) + + degrees = function(hypothesis.strategies.booleans()) + + return ( + { + "size": size, + "degrees": degrees, + }, + beignet.random_rotation_vector(size, degrees=degrees), + ) + + class TestRandomRotationVectorDataset: - def test___init__(self): - assert True + @hypothesis.given(_strategy()) + def test___init__(self, data): + parameters, output = data + + dataset = RandomRotationVectorDataset(**parameters) + + assert dataset.data.shape == output.shape + + assert dataset.data.dtype == output.dtype + + assert dataset.data.layout == output.layout