sch.py
import math

from tensorflow import keras


# Note: these callbacks use Keras 2-era APIs (`optimizer.lr`,
# `keras.backend.set_value`/`get_value`) and may need updating for Keras 3.
class CosineAnnealingScheduler(keras.callbacks.Callback):
    """Cosine annealing learning rate scheduler.

    Arguments:
        T_max {int} -- Number of epochs over which to anneal.
        eta_max {float} -- Maximum (initial) learning rate.
        eta_min {float} -- Minimum (final) learning rate. (default: {0})
        verbose {int} -- 0: quiet, 1: update messages. (default: {1})
    """

    def __init__(self, T_max, eta_max, eta_min=0, verbose=1):
        super(CosineAnnealingScheduler, self).__init__()
        self.T_max = T_max
        self.eta_max = eta_max
        self.eta_min = eta_min
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have an "lr" attribute.')
        # eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2:
        # starts at eta_max and decays to eta_min along half a cosine.
        lr = self.eta_min + (self.eta_max - self.eta_min) * (1 + math.cos(math.pi * epoch / self.T_max)) / 2
        keras.backend.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nEpoch %05d: CosineAnnealingScheduler setting learning rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        # Record the current learning rate so it shows up in the training logs.
        logs = logs or {}
        logs['lr'] = keras.backend.get_value(self.model.optimizer.lr)
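

# Usage sketch for CosineAnnealingScheduler (illustrative; the toy model and
# random data below are assumptions, not part of the original file). With
# T_max equal to the number of epochs, the learning rate traces half a cosine
# from eta_max down to eta_min over the run.
def _demo_cosine_annealing():
    import numpy as np

    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1),
                  loss='mse')
    x = np.random.rand(64, 4).astype('float32')
    y = np.random.rand(64, 1).astype('float32')
    # Anneal from 0.1 down to 1e-4 over 10 epochs.
    model.fit(x, y, epochs=10, batch_size=8, verbose=0,
              callbacks=[CosineAnnealingScheduler(T_max=10, eta_max=0.1,
                                                  eta_min=1e-4)])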

class WarmUpLearningRateScheduler(keras.callbacks.Callback):
    """Warmup learning rate scheduler."""

    def __init__(self, warmup_batches, init_lr, verbose=0):
        """Constructor for warmup learning rate scheduler.

        Arguments:
            warmup_batches {int} -- Number of batches for warmup.
            init_lr {float} -- Learning rate after warmup.

        Keyword Arguments:
            verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """
        super(WarmUpLearningRateScheduler, self).__init__()
        self.warmup_batches = warmup_batches
        self.init_lr = init_lr
        self.verbose = verbose
        self.batch_count = 0
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.batch_count = self.batch_count + 1
        lr = keras.backend.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        if self.batch_count <= self.warmup_batches:
            # Linear ramp: the rate rises from 0 on the first batch up to
            # init_lr once batch_count reaches warmup_batches.
            lr = self.batch_count * self.init_lr / self.warmup_batches
            keras.backend.set_value(self.model.optimizer.lr, lr)
            if self.verbose > 0:
                print('\nBatch %05d: WarmUpLearningRateScheduler setting learning rate to %s.' % (self.batch_count + 1, lr))
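

# Usage sketch for WarmUpLearningRateScheduler (illustrative; the model and
# data below are assumptions). The warmup length is counted in batches, so
# for two epochs of warmup at 8 steps per epoch, warmup_batches = 16. The
# callback's `learning_rates` list records the per-batch rate and can be
# inspected after training to verify the ramp.
def _demo_warmup():
    import numpy as np

    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.01),
                  loss='mse')
    x = np.random.rand(64, 4).astype('float32')
    y = np.random.rand(64, 1).astype('float32')
    # 64 samples / batch_size 8 = 8 steps per epoch; ramp over 2 epochs.
    warmup = WarmUpLearningRateScheduler(warmup_batches=16, init_lr=0.01)
    model.fit(x, y, epochs=5, batch_size=8, verbose=0, callbacks=[warmup])
    print(warmup.learning_rates)  # ramps linearly up to 0.01, then holds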