optimizer.py
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import logging
import sys
from logger import logger


class AdamOptimWrapper(object):
    '''
    A wrapper around the Adam optimizer that adjusts the optimization parameters
    according to the strategy presented in the paper
    '''
    def __init__(self, params, lr, wd=0, t0=15000, t1=25000, *args, **kwargs):
        super(AdamOptimWrapper, self).__init__(*args, **kwargs)
        self.base_lr = lr
        self.wd = wd
        self.t0 = t0
        self.t1 = t1
        self.step_count = 0
        self.optim = torch.optim.Adam(params,
                lr=self.base_lr,
                weight_decay=self.wd)

    def step(self):
        self.step_count += 1
        self.optim.step()
        # adjust optimizer parameters according to the schedule
        if self.step_count == self.t0:
            # at step t0, switch beta1 from the Adam default to 0.5
            betas_old = self.optim.param_groups[0]['betas']
            self.optim.param_groups[0]['betas'] = (0.5, 0.999)
            betas_new = self.optim.param_groups[0]['betas']
            logger.info('==> changing adam betas from {} to {}'.format(betas_old, betas_new))
            logger.info('==> start dropping lr exponentially')
        elif self.t0 < self.step_count < self.t1:
            # between t0 and t1, decay the learning rate exponentially
            # from base_lr towards base_lr * 0.001
            lr = self.base_lr * (0.001 ** ((self.step_count + 1.0 - self.t0) / (self.t1 + 1.0 - self.t0)))
            for pg in self.optim.param_groups:
                pg['lr'] = lr
            self.optim.defaults['lr'] = lr

    def zero_grad(self):
        self.optim.zero_grad()

    @property
    def lr(self):
        return self.optim.param_groups[0]['lr']
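
A minimal usage sketch of the wrapper (not part of the original file): the model, loss, and data loader below are illustrative assumptions, not the project's actual training code; only the AdamOptimWrapper calls come from the class above.

#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn

from optimizer import AdamOptimWrapper

# hypothetical stand-in for the real network and hyperparameters
model = nn.Linear(128, 10)
optim = AdamOptimWrapper(model.parameters(), lr=3e-4, wd=5e-4, t0=15000, t1=25000)

# `loader` is assumed to yield (inputs, labels) batches
for inputs, labels in loader:
    logits = model(inputs)
    loss = nn.functional.cross_entropy(logits, labels)
    optim.zero_grad()
    loss.backward()
    # step() counts iterations internally: betas switch at t0,
    # then the lr decays exponentially until t1
    optim.step()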