From 639543730d4e0912a598b19015f63a023008bdad Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Wed, 22 May 2024 12:50:34 -0400 Subject: [PATCH] RandomEulerAngleDataset --- .../datasets/__random_rotation_dataset.py | 5 ++-- .../datasets/_random_euler_angle_dataset.py | 25 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/beignet/datasets/__random_rotation_dataset.py b/src/beignet/datasets/__random_rotation_dataset.py index 325fbb0913..edfc28373e 100644 --- a/src/beignet/datasets/__random_rotation_dataset.py +++ b/src/beignet/datasets/__random_rotation_dataset.py @@ -7,15 +7,16 @@ class RandomRotationDataset(Dataset): - data: Tensor - def __init__( self, + data: Tensor, *, transform: Callable | Transform | None = None, ) -> None: super().__init__() + self.data = data + self.transform = transform def __getitem__(self, index: int) -> Tensor: diff --git a/src/beignet/datasets/_random_euler_angle_dataset.py b/src/beignet/datasets/_random_euler_angle_dataset.py index 923a46243b..48444f4434 100644 --- a/src/beignet/datasets/_random_euler_angle_dataset.py +++ b/src/beignet/datasets/_random_euler_angle_dataset.py @@ -23,16 +23,17 @@ def __init__( requires_grad: bool | None = False, transform: Callable | Transform | None = None, ) -> None: - self.data = beignet.random_euler_angle( - size, - axes, - degrees, - generator=generator, - dtype=dtype, - layout=layout, - device=device, - requires_grad=requires_grad, - pin_memory=pin_memory, + super().__init__( + beignet.random_euler_angle( + size, + axes, + degrees, + generator=generator, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + ), + transform=transform, ) - - super().__init__(transform=transform)