-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_generator.py
55 lines (40 loc) · 1.66 KB
/
batch_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Any, Generator
import numpy as np
from gns.utils.shuffle_inplace import shuffle_inplace
def batch_generator(
    data, batch_size=32, epochs=None, shuffle=True
) -> Generator[Any, None, None]:
    """
    Yield mini-batches from one or more aligned arrays, epoch after epoch.

    The generator walks over ``data`` in consecutive slices of at most
    ``batch_size`` samples. The last batch of each epoch may be smaller than
    ``batch_size``. When ``shuffle`` is true, ``shuffle_inplace(*data)`` is
    called before every epoch. Judging by its name it mutates the caller's
    arrays in place. It is assumed to apply one shared permutation so rows
    stay aligned across arrays (TODO confirm against ``shuffle_inplace``).

    Because this is a generator function, argument validation runs lazily,
    i.e. on the first ``next()`` call, not when the function is called.

    Args:
        data: a numpy array, or a list/tuple of arrays sharing the same
            length (first dimension).
        batch_size: maximum number of samples per batch; must be >= 1.
        epochs: number of passes over the data. ``None`` or ``-1`` means
            iterate forever.
        shuffle: whether to shuffle the data before each epoch.

    Yields:
        A single slice when ``data`` holds exactly one array, otherwise a
        list with one slice per input array, in the same order as ``data``.

    Raises:
        ValueError: if ``data`` is an empty list/tuple, the inputs differ in
            length, or ``batch_size`` is smaller than 1.
    """
    # Normalise a lone array into a one-element list so the rest of the
    # code can treat every input uniformly.
    if not isinstance(data, (list, tuple)):
        data = [data]
    if len(data) < 1:
        raise ValueError("Data should not be empty.")
    if len({len(item) for item in data}) > 1:
        raise ValueError("All inputs should have the same length (__len__).")
    # batch_size == 0 would divide by zero. A negative value would give a
    # negative batch count, making an infinite run spin without yielding.
    if batch_size < 1:
        raise ValueError("batch_size should be a positive integer.")
    if epochs is None or epochs == -1:
        epochs = np.inf

    n_samples = len(data[0])
    batches_per_epoch_count = int(np.ceil(n_samples / batch_size))
    # Zero-length inputs produce no batches at all. Without this early exit
    # an infinite ``epochs`` setting would loop forever without yielding.
    if batches_per_epoch_count == 0:
        return

    single_input = len(data) == 1
    epoch_number = 0
    while epoch_number < epochs:
        epoch_number += 1
        if shuffle:
            shuffle_inplace(*data)
        for batch in range(batches_per_epoch_count):
            start = batch * batch_size
            stop = min(start + batch_size, n_samples)
            batch_items = [item[start:stop] for item in data]
            # Unwrap the list for the single-input case so callers that
            # passed one array get one array back.
            yield batch_items[0] if single_input else batch_items