Skip to content

Commit

Permalink
RandomRotationVectorDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 22, 2024
1 parent 55bc0a5 commit dd7ef6e
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions src/beignet/datasets/_random_rotation_vector_dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
from typing import Callable, Optional
from typing import Callable, Generator

from torch import Tensor
import torch
from torch.utils.data import Dataset

import beignet
from beignet.transforms import Transform

class RandomAxisAngleDataset(Dataset):

class RandomRotationVectorDataset(Dataset):
def __init__(
self,
size: int,
degrees: bool = False,
*,
transform: Optional[Callable] = None,
) -> None:
super().__init__()

self.size = size

self.transform = transform

def __getitem__(self, index: int) -> Tensor:
raise NotImplementedError

def __len__(self) -> int:
return self.size
device: torch.device | None = None,
dtype: torch.dtype | None = None,
generator: Generator | None = None,
layout: torch.layout | None = torch.strided,
pin_memory: bool | None = False,
requires_grad: bool | None = False,
transform: Callable | Transform | None = None,
):
super().__init__(
beignet.random_rotation_vector(
size,
degrees,
generator=generator,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory,
),
transform=transform,
)

0 comments on commit dd7ef6e

Please sign in to comment.