Skip to content

Commit

Permalink
RandomRotationMatrixDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 22, 2024
1 parent f00dde7 commit 55bc0a5
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions src/beignet/datasets/_random_rotation_matrix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
from typing import Callable, Optional
from typing import Callable, Generator

from torch import Tensor
import torch
from torch.utils.data import Dataset

import beignet
from beignet.transforms import Transform


class RandomRotationMatrixDataset(Dataset):
def __init__(
self,
size: int,
*,
transform: Optional[Callable] = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
generator: Generator | None = None,
layout: torch.layout | None = torch.strided,
pin_memory: bool | None = False,
requires_grad: bool | None = False,
transform: Callable | Transform | None = None,
):
super().__init__()

self.size = size

self.transform = transform

def __getitem__(self, index: int) -> Tensor:
raise NotImplementedError

def __len__(self) -> int:
return self.size
super().__init__(
beignet.random_rotation_matrix(
size,
generator=generator,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory,
),
transform=transform,
)

0 comments on commit 55bc0a5

Please sign in to comment.