-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_generator.py
33 lines (28 loc) · 1.28 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Batch generator class that produces training batches for the LSTM model.
import numpy as np
from util import vocabulary_size, char2id
class BatchGenerator(object):
    """Produce successive one-hot character batches from a text.

    The text is divided into ``batch_size`` equally sized segments and one
    cursor is kept at the start of each segment. Every generated batch holds
    one encoded character per cursor; afterwards each cursor advances by one
    position, wrapping around at the end of the text.
    """

    def __init__(self, text, batch_size, num_unrollings):
        """Initialise the cursors and pre-compute the first batch.

        Args:
            text: Indexable sequence of characters to draw batches from.
            batch_size: Number of rows per batch (one cursor per row).
            num_unrollings: Number of new batches produced by each call to
                ``next()``.

        Raises:
            ValueError: If ``text`` is empty or ``batch_size`` is less than 1.
        """
        self._text = text
        self._text_size = len(text)
        # Fail early with a clear message instead of an IndexError /
        # ZeroDivisionError further down.
        if self._text_size == 0:
            raise ValueError("text must not be empty")
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        # Spread the cursors evenly over the text, one per batch row.
        segment = self._text_size // batch_size
        self._cursor = [offset * segment for offset in range(batch_size)]
        self._last_batch = self._next_batch()

    def _next_batch(self):
        """Generate a single batch from the current cursor position in the data.

        Returns:
            A float64 array of shape ``(batch_size, vocabulary_size)`` in
            which row ``b`` has a 1.0 in the column given by ``char2id`` for
            the character under cursor ``b`` and zeros elsewhere.
        """
        # np.float64 replaces the removed ``np.float`` alias (same dtype).
        batch = np.zeros(
            shape=(self._batch_size, vocabulary_size), dtype=np.float64
        )
        for b in range(self._batch_size):
            batch[b, char2id(self._text[self._cursor[b]])] = 1.0
            # Advance this row's cursor, wrapping at the end of the text.
            self._cursor[b] = (self._cursor[b] + 1) % self._text_size
        return batch

    def next(self):
        """Generate the next array of batches from the data.

        The array consists of the last batch of the previous array, followed
        by ``num_unrollings`` new ones.

        Returns:
            A list of ``num_unrollings + 1`` batch arrays.
        """
        batches = [self._last_batch]
        batches.extend(
            self._next_batch() for _ in range(self._num_unrollings)
        )
        # Remember the final batch so the next call can start with it.
        self._last_batch = batches[-1]
        return batches