Skip to content

Commit

Permalink
RandomEulerAngleDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 22, 2024
1 parent 6395437 commit f00dde7
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/beignet/datasets/__random_rotation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(
data: Tensor,
*,
transform: Callable | Transform | None = None,
) -> None:
):
super().__init__()

self.data = data
Expand Down
2 changes: 1 addition & 1 deletion src/beignet/datasets/_random_euler_angle_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
pin_memory: bool | None = False,
requires_grad: bool | None = False,
transform: Callable | Transform | None = None,
) -> None:
):
super().__init__(
beignet.random_euler_angle(
size,
Expand Down
45 changes: 26 additions & 19 deletions src/beignet/datasets/_random_quaternion_dataset.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
from typing import Callable
from typing import Callable, Generator

from torch import Generator, Tensor
from torch.utils.data import Dataset
import torch

import beignet
from beignet.transforms import Transform

from .__random_rotation_dataset import RandomRotationDataset

class RandomQuaternionDataset(Dataset):

class RandomQuaternionDataset(RandomRotationDataset):
def __init__(
self,
size: int,
canonical: bool = False,
*,
generator: Generator = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
generator: Generator | None = None,
layout: torch.layout | None = torch.strided,
pin_memory: bool | None = False,
requires_grad: bool | None = False,
transform: Callable | Transform | None = None,
) -> None:
super().__init__()

self.size = size

self.generator = generator

self.transform = transform

def __getitem__(self, _index: int) -> Tensor:
return beignet.random_quaternion(1, generator=self.generator)

def __len__(self) -> int:
return self.size
):
super().__init__(
beignet.random_quaternion(
size,
canonical,
generator=generator,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory,
),
transform=transform,
)
2 changes: 1 addition & 1 deletion src/beignet/datasets/_random_rotation_matrix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
size: int,
*,
transform: Optional[Callable] = None,
) -> None:
):
super().__init__()

self.size = size
Expand Down

0 comments on commit f00dde7

Please sign in to comment.