optimisers.py
# -*- coding: utf-8 -*-
"""Model optimisers.

This module contains objects implementing (batched) stochastic gradient descent
based optimisation of models.
"""

import time
import logging
from collections import OrderedDict
import numpy as np
import tqdm

logger = logging.getLogger(__name__)

class Optimiser(object):
    """Basic model optimiser."""

    def __init__(self, model, error, learning_rule, train_dataset,
                 valid_dataset=None, data_monitors=None, notebook=False):
        """Create a new optimiser instance.

        Args:
            model: The model to optimise.
            error: The scalar error function to minimise.
            learning_rule: Gradient based learning rule to use to minimise
                error.
            train_dataset: Data provider for training set data batches.
            valid_dataset: Data provider for validation set data batches.
            data_monitors: Dictionary of functions evaluated on targets and
                model outputs (averaged across both full training and
                validation data sets) to monitor during training in addition
                to the error. Keys should correspond to a string label for
                the statistic being evaluated.
            notebook: Whether the experiment is being run in an IPython
                notebook, used to select an appropriate progress bar style.
        """
        self.model = model
        self.error = error
        self.learning_rule = learning_rule
        self.learning_rule.initialise(self.model.params)
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.data_monitors = OrderedDict([('error', error)])
        self.notebook = notebook
        if notebook:
            self.tqdm_progress = tqdm.tqdm_notebook
        else:
            self.tqdm_progress = tqdm.tqdm
        if data_monitors is not None:
            self.data_monitors.update(data_monitors)

    def do_training_epoch(self):
        """Do a single training epoch.

        This iterates through all batches in the training dataset, for each
        batch calculating the gradient of the estimated error with respect
        to all the model parameters and then updating the model parameters
        according to the learning rule.
        """
        with self.tqdm_progress(total=self.train_dataset.num_batches) as train_progress_bar:
            train_progress_bar.set_description("Epoch Progress")
            for inputs_batch, targets_batch in self.train_dataset:
                activations = self.model.fprop(inputs_batch)
                grads_wrt_outputs = self.error.grad(activations[-1], targets_batch)
                grads_wrt_params = self.model.grads_wrt_params(
                    activations, grads_wrt_outputs)
                self.learning_rule.update_params(grads_wrt_params)
                train_progress_bar.update(1)

    def eval_monitors(self, dataset, label):
        """Evaluates the monitors for the given dataset.

        Args:
            dataset: Dataset to perform evaluation with.
            label: Tag to add to end of monitor keys to identify dataset.

        Returns:
            OrderedDict of monitor values evaluated on dataset.
        """
        data_mon_vals = OrderedDict([(key + label, 0.) for key
                                     in self.data_monitors.keys()])
        for inputs_batch, targets_batch in dataset:
            activations = self.model.fprop(inputs_batch)
            for key, data_monitor in self.data_monitors.items():
                data_mon_vals[key + label] += data_monitor(
                    activations[-1], targets_batch)
        for key, data_monitor in self.data_monitors.items():
            data_mon_vals[key + label] /= dataset.num_batches
        return data_mon_vals

    def get_epoch_stats(self):
        """Computes training statistics for an epoch.

        Returns:
            An OrderedDict with keys corresponding to the statistic labels and
            values corresponding to the value of the statistic.
        """
        epoch_stats = OrderedDict()
        epoch_stats.update(self.eval_monitors(self.train_dataset, '(train)'))
        if self.valid_dataset is not None:
            epoch_stats.update(self.eval_monitors(
                self.valid_dataset, '(valid)'))
        epoch_stats['params_penalty'] = self.model.params_penalty()
        return epoch_stats

    def log_stats(self, epoch, epoch_time, stats):
        """Outputs stats for a training epoch to a logger.

        Args:
            epoch (int): Epoch counter.
            epoch_time: Time taken in seconds for the epoch to complete.
            stats: Monitored stats for the epoch.
        """
        logger.info('Epoch {0}: {1:.1f}s to complete\n {2}'.format(
            epoch, epoch_time,
            ', '.join(['{0}={1:.2e}'.format(k, v) for (k, v) in stats.items()])
        ))

    def train(self, num_epochs, stats_interval=5):
        """Trains a model for a set number of epochs.

        Args:
            num_epochs: Number of epochs (complete passes through the training
                dataset) to train for.
            stats_interval: Training statistics will be recorded and logged
                every `stats_interval` epochs.

        Returns:
            Tuple of (run_stats, keys, total_train_time): an array of the
            recorded training run statistics, a dict mapping each statistic
            label to its column index in the array, and the total training
            time in seconds.
        """
        start_train_time = time.time()
        # Record the initial statistics; this also guarantees `stats` is
        # defined at return even if no `stats_interval` epoch is reached.
        stats = self.get_epoch_stats()
        run_stats = [list(stats.values())]
        with self.tqdm_progress(total=num_epochs) as progress_bar:
            progress_bar.set_description("Experiment Progress")
            for epoch in range(1, num_epochs + 1):
                start_time = time.time()
                self.do_training_epoch()
                epoch_time = time.time() - start_time
                if epoch % stats_interval == 0:
                    stats = self.get_epoch_stats()
                    self.log_stats(epoch, epoch_time, stats)
                    run_stats.append(list(stats.values()))
                progress_bar.update(1)
        finish_train_time = time.time()
        total_train_time = finish_train_time - start_train_time
        return np.array(run_stats), {k: i for i, k in enumerate(stats.keys())}, total_train_time
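

if __name__ == '__main__':
    # Usage sketch (illustrative only, not part of the original module): a
    # minimal smoke test showing the interfaces `Optimiser` relies on. The
    # classes below are hypothetical stand-ins; in a real experiment the model,
    # error, learning rule and data providers would come from the rest of the
    # framework. They only need to provide what the class above actually uses:
    # model.fprop / grads_wrt_params / params / params_penalty, a callable
    # error with a grad method, learning_rule.initialise / update_params, and
    # data providers that iterate over (inputs, targets) batches and expose
    # `num_batches`.

    class LinearModel(object):
        """Single linear map y = x.dot(w) (hypothetical stand-in)."""
        def __init__(self, input_dim):
            self.params = [np.zeros(input_dim)]

        def fprop(self, inputs):
            # Activations list: layer inputs followed by layer outputs.
            return [inputs, inputs.dot(self.params[0])]

        def grads_wrt_params(self, activations, grads_wrt_outputs):
            return [activations[0].T.dot(grads_wrt_outputs)]

        def params_penalty(self):
            return 0.

    class SumSquaredError(object):
        """Mean squared error and its gradient (hypothetical stand-in)."""
        def __call__(self, outputs, targets):
            return 0.5 * np.mean((outputs - targets) ** 2)

        def grad(self, outputs, targets):
            return (outputs - targets) / outputs.shape[0]

    class GradientDescentRule(object):
        """Plain gradient descent update (hypothetical stand-in)."""
        def __init__(self, learning_rate=0.1):
            self.learning_rate = learning_rate

        def initialise(self, params):
            self.params = params

        def update_params(self, grads_wrt_params):
            for param, grad in zip(self.params, grads_wrt_params):
                param -= self.learning_rate * grad

    class ArrayDataProvider(object):
        """Iterates over fixed-size batches of an in-memory dataset."""
        def __init__(self, inputs, targets, batch_size):
            self.inputs, self.targets = inputs, targets
            self.batch_size = batch_size
            self.num_batches = inputs.shape[0] // batch_size

        def __iter__(self):
            for b in range(self.num_batches):
                idx = slice(b * self.batch_size, (b + 1) * self.batch_size)
                yield self.inputs[idx], self.targets[idx]

    rng = np.random.RandomState(0)
    inputs = rng.normal(size=(100, 3))
    targets = inputs.dot(np.array([1., -2., 0.5]))
    data = ArrayDataProvider(inputs, targets, batch_size=10)

    optimiser = Optimiser(
        LinearModel(3), SumSquaredError(), GradientDescentRule(0.1),
        train_dataset=data, valid_dataset=data)
    stats, keys, run_time = optimiser.train(num_epochs=5, stats_interval=1)
    # Final validation error, looked up via the returned column-index mapping.
    print('final error(valid) = {0:.3e}'.format(stats[-1, keys['error(valid)']]))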