Commit: train baseline models
jxu7 committed Jun 15, 2017
1 parent e7c397a commit 4f0c747
Showing 2 changed files with 18 additions and 19 deletions.
15 changes: 8 additions & 7 deletions trainers/baseline_trainer.py
@@ -21,7 +21,7 @@
 -------parameters---------
 epochs: 80
-batch size: 96, 96, 96, 60, 60, 60, 60
+batch size: 128, 128, 128, 64, 64, 64, 64
 use SGD + 0.9 momentum w/o nesterov
@@ -41,8 +41,8 @@
     resnet18_planet, resnet34_planet, resnet50_planet,
     densenet121, densenet169, densenet161,
 ]
-batch_size = [96, 96, 96, 60,
-              60, 32, 60]
+batch_size = [128, 128, 128, 64,
+              64, 64, 64]


 def get_dataloader(batch_size):
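For context: the batch_size list is positionally paired with the model list, one entry per architecture. A hedged sketch of the presumed driver loop — the zip pairing and the name lookup are assumptions, since only the batch-size assignments are visible in the next hunk:

# Hypothetical pairing, assuming models and batch_size line up by index.
for model_fn, batch in zip(models, batch_size):
    net = model_fn()                # build this architecture
    name = model_fn.__name__        # e.g. 'resnet18_planet'
    train_data.batch_size = batch   # per-model batch size, as in the hunk below
    val_data.batch_size = batch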
@@ -97,12 +97,13 @@ def train_baselines():
         train_data.batch_size = batch
         val_data.batch_size = batch

-        num_epoches = 50  # 100
+        num_epoches = 100  # 100
         print_every_iter = 20
         epoch_test = 1

         # optimizer
-        optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
+        # optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
+        optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-4)

         smooth_loss = 0.0
         train_loss = np.nan
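Two things worth noting about the optimizer swap: the learning rate drops from 0.1 to 1e-4, the usual range change when moving from SGD+momentum to Adam, and the module docstring above still says "use SGD", so it is now stale. Also, weight_decay in PyTorch's Adam is an L2 penalty folded into the gradient; a hedged alternative sketch (AdamW only exists in PyTorch releases newer than this commit targets):

import torch.optim as optim

# As committed: Adam with L2-style weight decay.
optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-4)

# In newer PyTorch, decoupled weight decay is available instead:
# optimizer = optim.AdamW(net.parameters(), lr=1e-4, weight_decay=5e-4)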
@@ -116,7 +117,7 @@
             # train loss averaged every epoch
             total_epoch_loss = 0.0

-            lr_schedule(epoch, optimizer)
+            # lr_schedule(epoch, optimizer)

             rate = get_learning_rate(optimizer)[0]  # check

@@ -165,7 +166,7 @@ def train_baselines():

             # save if the current loss is better
             if test_loss < best_test_loss:
-                torch.save(net, '../models/{}.pth'.format(name))
+                torch.save(net.state_dict(), '../models/{}.pth'.format(name))
                 best_test_loss = test_loss

             logger.add_record('train_loss', total_epoch_loss)
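Saving net.state_dict() rather than the whole pickled module is the more portable choice: a full-module checkpoint breaks when the class definition moves or changes, while a state dict is just a name-to-tensor map. The loading side changes accordingly; a minimal sketch, assuming resnet18_planet() rebuilds the matching architecture:

import torch

# Old checkpoints (whole module): net = torch.load('../models/resnet18_planet.pth')
# New checkpoints (state dict): rebuild the model, then load weights.
net = resnet18_planet()
net.load_state_dict(torch.load('../models/resnet18_planet.pth'))
net.eval()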
22 changes: 10 additions & 12 deletions util.py
@@ -68,7 +68,7 @@ def optimize_threshold(fnames, labels, resolution):
             r /= resolution
             threshold[i] = r
             # labels = get_labels(pred, threshold)
-            preds = (results > threshold).dtype(np.int32)
+            preds = (results > threshold).astype(np.int32)
             score = f2_score(preds, labels)
             if score > best_score:
                 best_thresh = r
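This hunk is a genuine bug fix, not a rename: ndarray.dtype is an attribute describing the array's element type, so calling it raises a TypeError, while .astype() performs the cast. A minimal repro:

import numpy as np

results = np.array([0.1, 0.7, 0.4])
threshold = 0.5
# (results > threshold).dtype(np.int32)  # TypeError: 'numpy.dtype' object is not callable
preds = (results > threshold).astype(np.int32)  # array([0, 1, 0], dtype=int32)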
@@ -132,16 +132,14 @@ def get_learning_rate(optimizer):


 def lr_schedule(epoch, optimizer):
-    if 0 <= epoch < 10:
-        lr = 1e-1
-    elif 10 <= epoch < 25:
-        lr = 0.01
-    elif 25 <= epoch < 35:
-        lr = 0.005
-    elif 35 <= epoch < 40:
-        lr = 0.001
+    if 0 <= epoch < 20:
+        lr = 1e-4
+    elif 20 <= epoch < 35:
+        lr = 9e-5
+    elif 35 <= epoch < 45:
+        lr = 5e-5
     else:
-        lr = 0.0001
+        lr = 5e-5

     for para_group in optimizer.param_groups:
         para_group['lr'] = lr
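The schedule is retuned from SGD-range steps (1e-1 down to 1e-4) to Adam-range ones (1e-4 down to 5e-5), even though its call site in baseline_trainer.py is commented out in this same commit, so it currently has no effect. Note also that the 35–45 branch and the else branch both set 5e-5, making the 45-epoch boundary a no-op. A behavior-equivalent sketch of the new schedule as a boundary table:

def lr_schedule(epoch, optimizer):
    # (upper bound, lr) pairs equivalent to the if/elif chain above;
    # everything from epoch 35 on gets 5e-5.
    for bound, lr in [(20, 1e-4), (35, 9e-5)]:
        if epoch < bound:
            break
    else:
        lr = 5e-5
    for para_group in optimizer.param_groups:
        para_group['lr'] = lr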
@@ -236,6 +234,6 @@ def save_time(self, start_time, end_time):
     height=256,
     width=256
 )
-files = ['densenet121.txt', 'densenet161.txt', 'densenet169.txt', 'resnet18_planet.txt',
-         'resnet34_planet.txt', 'resnet50_planet.txt']
+files = ['probs/densenet121.txt', 'probs/densenet161.txt', 'probs/densenet169.txt', 'probs/resnet18_planet.txt',
+         'probs/resnet34_planet.txt', 'probs/resnet50_planet.txt']
 optimize_threshold(files, resolution=500, labels=validation.labels)
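For context on what this call does: optimize_threshold reads the per-model probability files, combines them (presumably by averaging), and greedily sweeps one threshold per label over `resolution` evenly spaced candidates, keeping whichever maximizes the F2 score (F-beta with beta=2, which weights recall over precision). A hedged sketch of the per-label sweep, using sklearn's fbeta_score in place of the repo's f2_score and an assumed 0.2 starting threshold:

import numpy as np
from sklearn.metrics import fbeta_score

def sweep_label(probs, labels, i, resolution=500, init=0.2):
    # probs: (n_samples, n_labels) combined probabilities; labels: 0/1 matrix.
    threshold = np.full(probs.shape[1], init)
    best_score, best_r = -1.0, init
    for step in range(resolution):
        threshold[i] = step / resolution
        preds = (probs > threshold).astype(np.int32)
        score = fbeta_score(labels, preds, beta=2, average='samples')
        if score > best_score:
            best_score, best_r = score, threshold[i]
    threshold[i] = best_r  # keep the best threshold for label i
    return threshold, best_score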
