diff --git a/src/beignet/datasets/_random_rotation_vector_dataset.py b/src/beignet/datasets/_random_rotation_vector_dataset.py index f26e685b56..4f6d43b77d 100644 --- a/src/beignet/datasets/_random_rotation_vector_dataset.py +++ b/src/beignet/datasets/_random_rotation_vector_dataset.py @@ -1,24 +1,36 @@ -from typing import Callable, Optional +from typing import Callable, Generator -from torch import Tensor +import torch from torch.utils.data import Dataset +import beignet +from beignet.transforms import Transform -class RandomAxisAngleDataset(Dataset): + +class RandomRotationVectorDataset(Dataset): def __init__( self, size: int, + degrees: bool = False, *, - transform: Optional[Callable] = None, - ) -> None: - super().__init__() - - self.size = size - - self.transform = transform - - def __getitem__(self, index: int) -> Tensor: - raise NotImplementedError - - def __len__(self) -> int: - return self.size + device: torch.device | None = None, + dtype: torch.dtype | None = None, + generator: Generator | None = None, + layout: torch.layout | None = torch.strided, + pin_memory: bool | None = False, + requires_grad: bool | None = False, + transform: Callable | Transform | None = None, + ): + super().__init__( + beignet.random_rotation_vector( + size, + degrees, + generator=generator, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + ), + transform=transform, + )