scheduler.py
from collections.abc import Iterable
from math import log, cos, pi, floor

from torch.optim.lr_scheduler import _LRScheduler


class CyclicCosineDecayLR(_LRScheduler):
    def __init__(self,
                 optimizer,
                 init_decay_epochs,
                 min_decay_lr,
                 restart_interval=None,
                 restart_interval_multiplier=None,
                 restart_lr=None,
                 warmup_epochs=None,
                 warmup_start_lr=None,
                 last_epoch=-1,
                 verbose=False):
        """
        Initialize new CyclicCosineDecayLR object.

        :param optimizer: (Optimizer) - Wrapped optimizer.
        :param init_decay_epochs: (int) - Number of initial decay epochs.
        :param min_decay_lr: (float or iterable of floats) - Learning rate at the end of decay.
        :param restart_interval: (int) - Restart interval for fixed cycles.
            Set to None to disable cycles. Default: None.
        :param restart_interval_multiplier: (float) - Multiplication coefficient for geometrically increasing cycles.
            Default: None.
        :param restart_lr: (float or iterable of floats) - Learning rate when a cycle restarts.
            If None, the optimizer's learning rate will be used. Default: None.
        :param warmup_epochs: (int) - Number of warmup epochs. Set to None to disable warmup. Default: None.
        :param warmup_start_lr: (float or iterable of floats) - Learning rate at the beginning of warmup.
            Must be set if warmup_epochs is not None. Default: None.
        :param last_epoch: (int) - The index of the last epoch. This parameter is used when resuming a training job. Default: -1.
        :param verbose: (bool) - If True, prints a message to stdout for each update. Default: False.

        See the usage sketch at the end of this file.
        """
        if not isinstance(init_decay_epochs, int) or init_decay_epochs < 1:
            raise ValueError("init_decay_epochs must be a positive integer, got {} instead".format(init_decay_epochs))

        if isinstance(min_decay_lr, Iterable) and len(min_decay_lr) != len(optimizer.param_groups):
            raise ValueError("Expected len(min_decay_lr) to be equal to len(optimizer.param_groups), "
                             "got {} and {} instead".format(len(min_decay_lr), len(optimizer.param_groups)))

        if restart_interval is not None and (not isinstance(restart_interval, int) or restart_interval < 1):
            raise ValueError("restart_interval must be a positive integer, got {} instead".format(restart_interval))

        if restart_interval_multiplier is not None and \
                (not isinstance(restart_interval_multiplier, float) or restart_interval_multiplier <= 0):
            raise ValueError("restart_interval_multiplier must be a positive float, got {} instead".format(
                restart_interval_multiplier))

        if isinstance(restart_lr, Iterable) and len(restart_lr) != len(optimizer.param_groups):
            raise ValueError("Expected len(restart_lr) to be equal to len(optimizer.param_groups), "
                             "got {} and {} instead".format(len(restart_lr), len(optimizer.param_groups)))

        if warmup_epochs is not None:
            if not isinstance(warmup_epochs, int) or warmup_epochs < 1:
                raise ValueError(
                    "Expected warmup_epochs to be a positive integer, got {} instead".format(warmup_epochs))

            if warmup_start_lr is None:
                raise ValueError("warmup_start_lr must be set when warmup_epochs is not None")

            if not (isinstance(warmup_start_lr, float) or isinstance(warmup_start_lr, Iterable)):
                raise ValueError("warmup_start_lr must be either a float or an iterable of floats, got {} instead".format(
                    warmup_start_lr))

            if isinstance(warmup_start_lr, Iterable) and len(warmup_start_lr) != len(optimizer.param_groups):
                raise ValueError("Expected len(warmup_start_lr) to be equal to len(optimizer.param_groups), "
                                 "got {} and {} instead".format(len(warmup_start_lr), len(optimizer.param_groups)))

        group_num = len(optimizer.param_groups)
        self._warmup_start_lr = [warmup_start_lr] * group_num if isinstance(warmup_start_lr, float) else warmup_start_lr
        self._warmup_epochs = 0 if warmup_epochs is None else warmup_epochs
        self._init_decay_epochs = init_decay_epochs
        self._min_decay_lr = [min_decay_lr] * group_num if isinstance(min_decay_lr, float) else min_decay_lr
        self._restart_lr = [restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr
        self._restart_interval = restart_interval
        self._restart_interval_multiplier = restart_interval_multiplier
        super(CyclicCosineDecayLR, self).__init__(optimizer, last_epoch, verbose=verbose)

    def get_lr(self):

        if self._warmup_epochs > 0 and self.last_epoch < self._warmup_epochs:
            # Warmup stage: cosine ramp from warmup_start_lr up to the base learning rates.
            return self._calc(self.last_epoch,
                              self._warmup_epochs,
                              self._warmup_start_lr,
                              self.base_lrs)

        elif self.last_epoch < self._init_decay_epochs + self._warmup_epochs:
            # Initial decay stage: cosine decay from the base learning rates down to min_decay_lr.
            return self._calc(self.last_epoch - self._warmup_epochs,
                              self._init_decay_epochs,
                              self.base_lrs,
                              self._min_decay_lr)
        else:
            if self._restart_interval is not None:
                if self._restart_interval_multiplier is None:
                    # Fixed-length restart cycles.
                    cycle_epoch = (self.last_epoch - self._init_decay_epochs - self._warmup_epochs) % self._restart_interval
                    lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
                    return self._calc(cycle_epoch,
                                      self._restart_interval,
                                      lrs,
                                      self._min_decay_lr)
                else:
                    # Geometrically growing restart cycles.
                    n = self._get_n(self.last_epoch - self._warmup_epochs - self._init_decay_epochs)
                    sn_prev = self._partial_sum(n)
                    cycle_epoch = self.last_epoch - sn_prev - self._warmup_epochs - self._init_decay_epochs
                    interval = self._restart_interval * self._restart_interval_multiplier ** n
                    lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
                    return self._calc(cycle_epoch,
                                      interval,
                                      lrs,
                                      self._min_decay_lr)
            else:
                # No restarts: hold the minimum learning rate after the initial decay.
                return self._min_decay_lr

    def _calc(self, t, T, lrs, min_lrs):
        # Cosine interpolation from lrs down to min_lrs over an interval of length T, evaluated at step t.
        return [min_lr + (lr - min_lr) * ((1 + cos(pi * t / T)) / 2)
                for lr, min_lr in zip(lrs, min_lrs)]

    def _get_n(self, epoch):
        # Index of the geometric cycle that `epoch` falls into.
        _t = 1 - (1 - self._restart_interval_multiplier) * epoch / self._restart_interval
        return floor(log(_t, self._restart_interval_multiplier))

    def _partial_sum(self, n):
        # Total length of the first n geometric cycles (partial sum of a geometric series).
        return self._restart_interval * (1 - self._restart_interval_multiplier ** n) / (
                1 - self._restart_interval_multiplier)
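

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the scheduler itself).
    # Assumptions: the toy Linear model, the SGD optimizer, and every hyperparameter
    # value below are placeholders chosen for the example, not recommendations.
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Warm up from 0.001 to the base lr of 0.1 over 5 epochs, decay to 0.001 over
    # the next 100 epochs, then restart at 0.05 with cycles that start at 20 epochs
    # and grow by a factor of 1.2 each time.
    scheduler = CyclicCosineDecayLR(optimizer,
                                    init_decay_epochs=100,
                                    min_decay_lr=0.001,
                                    restart_interval=20,
                                    restart_interval_multiplier=1.2,
                                    restart_lr=0.05,
                                    warmup_epochs=5,
                                    warmup_start_lr=0.001)

    for epoch in range(200):
        # ... run one epoch of training here, then step the optimizer and the scheduler ...
        optimizer.step()
        scheduler.step()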