Skip to content

Commit

Permalink
DenseNet161 with SGD
Browse files Browse the repository at this point in the history
  • Loading branch information
Junhong Xu committed Jun 15, 2017
1 parent 4f0c747 commit c93568a
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 29 deletions.
89 changes: 70 additions & 19 deletions planet_models/resnet_planet.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,77 @@
from torchvision.models.resnet import BasicBlock, ResNet, resnet34, resnet50, resnet101, resnet152, resnet18
from torchvision.models.resnet import BasicBlock, ResNet, resnet34, resnet50, resnet101, resnet152, \
resnet18, model_urls, model_zoo
import torch.nn as nn
import math


def resnet34_planet():
return resnet34(False, num_classes=17)


def resnet101_planet():
return resnet101(pretrained=False, num_classes=17)


def resnet50_planet():
return resnet50(pretrained=False, num_classes=17)


def resnet152_planet():
return resnet152(pretrained=False, num_classes=17)


def resnet18_planet():
return resnet18(pretrained=False, num_classes=17)
def resnet34_planet(pretrained=False):
    """ResNet-34 with a 17-way classifier head for the Planet dataset.

    When ``pretrained`` is True, every layer except the final ``fc``
    classifier is initialised from the ImageNet checkpoint (the head has
    17 outputs instead of 1000, so its weights cannot be reused).
    """
    model = resnet34(False, num_classes=17)
    if pretrained:
        # Merge the ImageNet weights into the freshly built model,
        # skipping the classifier whose shape differs.
        state = model.state_dict()
        imagenet_state = model_zoo.load_url(model_urls['resnet34'])
        state.update({name: tensor for name, tensor in imagenet_state.items() if 'fc' not in name})
        model.load_state_dict(state)
    return model


def resnet101_planet(pretrained=False):
    """ResNet-101 with a 17-way classifier head for the Planet dataset.

    When ``pretrained`` is True, every layer except the final ``fc``
    classifier is initialised from the ImageNet checkpoint (the head has
    17 outputs instead of 1000, so its weights cannot be reused).
    """
    model = resnet101(False, num_classes=17)
    if pretrained:
        # Merge the ImageNet weights into the freshly built model,
        # skipping the classifier whose shape differs.
        state = model.state_dict()
        imagenet_state = model_zoo.load_url(model_urls['resnet101'])
        state.update({name: tensor for name, tensor in imagenet_state.items() if 'fc' not in name})
        model.load_state_dict(state)
    return model


def resnet50_planet(pretrained=False):
    """ResNet-50 with a 17-way classifier head for the Planet dataset.

    When ``pretrained`` is True, every layer except the final ``fc``
    classifier is initialised from the ImageNet checkpoint (the head has
    17 outputs instead of 1000, so its weights cannot be reused).
    """
    model = resnet50(False, num_classes=17)
    if pretrained:
        # Merge the ImageNet weights into the freshly built model,
        # skipping the classifier whose shape differs.
        state = model.state_dict()
        imagenet_state = model_zoo.load_url(model_urls['resnet50'])
        state.update({name: tensor for name, tensor in imagenet_state.items() if 'fc' not in name})
        model.load_state_dict(state)
    return model


def resnet152_planet(pretrained=False):
    """ResNet-152 with a 17-way classifier head for the Planet dataset.

    Args:
        pretrained: when True, initialise every layer except the final
            ``fc`` classifier from the ImageNet resnet152 checkpoint.
    """
    # BUG FIX: this previously constructed resnet50 and then tried to load
    # the resnet152 checkpoint into it. dict.update injects the extra
    # resnet152 keys into the resnet50 state dict, so the (strict)
    # load_state_dict call raises on unexpected keys; and with
    # pretrained=False the function silently returned the wrong model.
    model = resnet152(False, num_classes=17)
    if pretrained:
        # load model dictionary
        model_dict = model.state_dict()
        # load pretrained model
        pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
        # update model dictionary using pretrained model without classifier layer
        model_dict.update({key: pretrained_dict[key] for key in pretrained_dict.keys() if 'fc' not in key})
        model.load_state_dict(model_dict)

    return model


def resnet18_planet(pretrained=False):
    """ResNet-18 with a 17-way classifier head for the Planet dataset.

    When ``pretrained`` is True, every layer except the final ``fc``
    classifier is initialised from the ImageNet checkpoint (the head has
    17 outputs instead of 1000, so its weights cannot be reused).
    """
    model = resnet18(False, num_classes=17)
    if pretrained:
        # Merge the ImageNet weights into the freshly built model,
        # skipping the classifier whose shape differs.
        state = model.state_dict()
        imagenet_state = model_zoo.load_url(model_urls['resnet18'])
        state.update({name: tensor for name, tensor in imagenet_state.items() if 'fc' not in name})
        model.load_state_dict(state)
    return model


def resnet14_planet():
Expand Down
26 changes: 24 additions & 2 deletions trainers/baseline_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,27 @@ def get_dataloader(batch_size):
return train_data_loader, valid_dataloader


def get_optimizer(net, lr, resnet=False, pretrained=False):
    """Build an SGD optimizer (momentum 0.9, weight decay 5e-5) for *net*.

    For a pretrained network the freshly initialised classifier layer is
    trained at 10x the base rate, while the pretrained feature layers use
    the base rate. ``resnet`` selects between the resnet attribute layout
    (``fc``/``layer1..4``) and the densenet-style one
    (``features``/``classifier``).
    """
    if not pretrained:
        # From-scratch training: one uniform rate for every parameter.
        return optim.SGD(params=net.parameters(), lr=lr, weight_decay=5e-5, momentum=.9)

    if resnet:
        # NOTE(review): net.conv1 / net.bn1 are not listed here, so the
        # stem is effectively frozen (never updated) — confirm intentional.
        param_groups = [{'params': net.fc.parameters(), 'lr': lr * 10}]
        param_groups += [
            {'params': stage.parameters(), 'lr': lr}
            for stage in (net.layer1, net.layer2, net.layer3, net.layer4)
        ]
    else:
        param_groups = [
            {'params': net.features.parameters(), 'lr': lr},
            {'params': net.classifier.parameters(), 'lr': lr * 10}
        ]
    return optim.SGD(params=param_groups, weight_decay=5e-5, momentum=.9)


def train_baselines():

train_data, val_data = get_dataloader(96)
Expand All @@ -103,7 +124,7 @@ def train_baselines():

# optimizer
# optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-4)
# optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-4)

smooth_loss = 0.0
train_loss = np.nan
Expand All @@ -114,10 +135,11 @@ def train_baselines():
t = time.time()

for epoch in range(num_epoches): # loop over the dataset multiple times
optimizer = get_optimizer(net, lr=.01, pretrained=True, resnet=True if 'resnet' in name else False)
# train loss averaged every epoch
total_epoch_loss = 0.0

# lr_schedule(epoch, optimizer)
lr_schedule(epoch, optimizer, pretrained=True)

rate = get_learning_rate(optimizer)[0] # check

Expand Down
26 changes: 18 additions & 8 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,25 @@ def get_learning_rate(optimizer):
return lr


def lr_schedule(epoch, optimizer):
if 0 <= epoch < 20:
lr = 1e-4
elif 20 <= epoch < 35:
lr = 9e-5
elif 35 <= epoch < 45:
lr = 5e-5
def lr_schedule(epoch, optimizer, pretrained=False):
    """Step-decay learning-rate schedule, applied in place to *optimizer*.

    Pretrained fine-tuning starts an order of magnitude lower (1e-2) than
    from-scratch training (1e-1); both decay in steps at epochs 10, 25
    and 40. The final fallback rate also covers negative epoch values.
    """
    # (epoch upper bound, rate) pairs — first bound the epoch falls under wins.
    if pretrained:
        steps, fallback = ((10, 1e-2), (25, 5e-3), (40, 1e-3)), 1e-4
    else:
        steps, fallback = ((10, 1e-1), (25, 5e-2), (40, 1e-2)), 1e-3
    lr = next((rate for bound, rate in steps if 0 <= epoch < bound), fallback)

    for group in optimizer.param_groups:
        group['lr'] = lr
Expand Down

0 comments on commit c93568a

Please sign in to comment.