-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
26 lines (20 loc) · 972 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from tensor import Tensor
import numpy as np
from typing import Iterator, NamedTuple
class BATCH(NamedTuple):
    """One batch of data: an ``inputs`` tensor paired with a ``targets`` tensor.

    Being a named tuple, it can be unpacked positionally
    (``inputs, targets = batch``) or read by field name.
    """

    # Input data for the batch.
    inputs: Tensor
    # Target data for the batch.
    targets: Tensor
class DataIterator:
    """Interface for callables that turn a dataset into a stream of batches.

    The base class provides no behaviour of its own: calling an instance
    raises ``NotImplementedError``. Subclasses override ``__call__``.
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Iterator[BATCH]:
        """Return an iterator of ``BATCH`` items built from the given data.

        Args:
            inputs: The input data to be batched.
            targets: The target data to be batched alongside ``inputs``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class BatchIterator(DataIterator):
    """Yield batches of data as ``BATCH(inputs, targets)`` pairs.

    The data is cut into consecutive windows of ``batch_size`` items. When
    ``shuffle`` is true, the order in which the windows are visited is
    randomised on every call (using NumPy's global random state); the items
    inside each window keep their original relative order.
    """

    def __init__(self, batch_size: int = 32, shuffle: bool = True) -> None:
        """Store the batching configuration.

        Args:
            batch_size: Number of items per batch; must be at least 1.
            shuffle: Whether to visit the batches in random order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A zero step makes np.arange fail and a negative step makes it
        # return an empty array, so reject both up front with a clear message.
        if batch_size < 1:
            raise ValueError(
                "batch_size must be a positive integer, got {!r}".format(batch_size)
            )
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __call__(self, inputs: Tensor, targets: Tensor) -> Iterator[BATCH]:
        """Iterate over ``inputs`` and ``targets`` in aligned batches.

        This is a generator, so the length check below runs when the first
        batch is requested rather than when the iterator is created.

        Args:
            inputs: The input data; sliced along its first axis.
            targets: The target data; must have the same length as ``inputs``.

        Yields:
            ``BATCH`` tuples holding matching slices of ``inputs`` and
            ``targets``. The final slice may be shorter than ``batch_size``
            when the data length is not a multiple of it (assuming ``Tensor``
            slicing clips at the end, as sequence slicing does).

        Raises:
            ValueError: If ``inputs`` and ``targets`` differ in length.
        """
        n_items = len(inputs)
        # Both arrays are sliced with the same offsets, so a length mismatch
        # would silently pair inputs with the wrong (or missing) targets.
        if len(targets) != n_items:
            raise ValueError(
                "inputs and targets must have the same length, got {} and {}".format(
                    n_items, len(targets)
                )
            )

        starts = np.arange(0, n_items, self.batch_size)
        if self.shuffle:
            # Shuffles the batch start offsets in place: batch order changes,
            # batch contents do not.
            np.random.shuffle(starts)

        for start in starts:
            end = start + self.batch_size
            yield BATCH(inputs[start:end], targets[start:end])