Skip to content

Commit

Permalink
fpnet62_wd_1e-4_adam_rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
Junhong Xu committed Jun 8, 2017
1 parent 249fb5a commit f0b63a9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 20 deletions.
14 changes: 14 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
std = [0.16730586, 0.14391145, 0.13747531]


class RandomVerticalFLip(object):
def __call__(self, img):
if random.random() < 0.5:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
return img


class RandomRotate(object):
def __call__(self, img):
if random.random() < 0.2:
img = img.rotate(45)
return img


def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['.png', 'jpg', '.jpeg'])

Expand Down
23 changes: 4 additions & 19 deletions trainers/train_pynet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,6 @@
NAME = 'fpnet62_wd_1e-4_adam_rotate'


class RandomVerticalFLip(object):
def __call__(self, img):
if random.random() < 0.5:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
return img


class RandomRotate(object):
def __call__(self, img):
if random.random() < 0.5:
rotation = np.random.randint(1, 90)
img = img.rotate(rotation)
return img


def get_optimizer(model, pretrained=True, lr=5e-5, weight_decay=5e-5):
if pretrained:
# no pretrained yet
Expand All @@ -33,13 +18,13 @@ def get_optimizer(model, pretrained=True, lr=5e-5, weight_decay=5e-5):

def lr_schedule(epoch, optimizer):
if epoch < 10:
lr = 5e-4
lr = 9e-4
elif 10 <= epoch <= 20:
lr = 1e-4
lr = 5e-4
elif 25 < epoch <= 45:
lr = 5e-5
lr = 1e-4
else:
lr = 1e-5
lr = 5e-5

for param_group in optimizer.param_groups:
param_group['lr'] = lr
Expand Down
1 change: 0 additions & 1 deletion util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def f2_score(y_true, y_pred):
return fbeta_score(y_true, y_pred, beta=2, average='samples')



class Logger(object):
def __init__(self, save_dir, name):
self.save_dir = save_dir
Expand Down

0 comments on commit f0b63a9

Please sign in to comment.