diff --git a/trainers/baseline_trainer.py b/trainers/baseline_trainer.py index 8b637f1..34a2889 100644 --- a/trainers/baseline_trainer.py +++ b/trainers/baseline_trainer.py @@ -21,7 +21,7 @@ -------parameters--------- epochs: 80 - batch size: 96, 96, 96, 60, 60, 60, 60 + batch size: 128, 128, 128, 64, 64, 64, 64 use SGD+0.9momentum w/o nestrov @@ -41,8 +41,8 @@ resnet18_planet, resnet34_planet, resnet50_planet, densenet121, densenet169, densenet161, ] -batch_size = [96, 96, 96, 60, - 60, 32, 60] +batch_size = [128, 128, 128, 64, + 64, 64, 64] def get_dataloader(batch_size): @@ -97,12 +97,13 @@ def train_baselines(): train_data.batch_size = batch val_data.batch_size = batch - num_epoches = 50 #100 + num_epoches = 100 #100 print_every_iter = 20 epoch_test = 1 # optimizer - optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005) + # optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005) + optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-4) smooth_loss = 0.0 train_loss = np.nan @@ -116,7 +117,7 @@ def train_baselines(): # train loss averaged every epoch total_epoch_loss = 0.0 - lr_schedule(epoch, optimizer) + # lr_schedule(epoch, optimizer) rate = get_learning_rate(optimizer)[0] # check @@ -165,7 +166,7 @@ def train_baselines(): # save if the current loss is better if test_loss < best_test_loss: - torch.save(net, '../models/{}.pth'.format(name)) + torch.save(net.state_dict(), '../models/{}.pth'.format(name)) best_test_loss = test_loss logger.add_record('train_loss', total_epoch_loss) diff --git a/util.py b/util.py index 4c2bdf6..d4341cb 100644 --- a/util.py +++ b/util.py @@ -68,7 +68,7 @@ def optimize_threshold(fnames, labels, resolution): r /= resolution threshold[i] = r # labels = get_labels(pred, threshold) - preds = (results > threshold).dtype(np.int32) + preds = (results > threshold).astype(np.int32) score = f2_score(preds, labels) if score > best_score: best_thresh = r @@ -132,16 +132,14 @@ def get_learning_rate(optimizer): def lr_schedule(epoch, optimizer): - if 0 <= epoch < 10: - lr = 1e-1 - elif 10 <= epoch < 25: - lr = 0.01 - elif 25 <= epoch < 35: - lr = 0.005 - elif 35 <= epoch < 40: - lr = 0.001 + if 0 <= epoch < 20: + lr = 1e-4 + elif 20 <= epoch < 35: + lr = 9e-5 + elif 35 <= epoch < 45: + lr = 5e-5 else: - lr = 0.0001 + lr = 5e-5 for para_group in optimizer.param_groups: para_group['lr'] = lr @@ -236,6 +234,6 @@ def save_time(self, start_time, end_time): height=256, width=256 ) - files = ['densenet121.txt', 'densenet161.txt', 'densenet169.txt', 'resnet18_planet.txt', - 'resnet34_planet.txt', 'resnet50_planet.txt'] + files = ['probs/densenet121.txt', 'probs/densenet161.txt', 'probs/densenet169.txt', 'probs/resnet18_planet.txt', + 'probs/resnet34_planet.txt', 'probs/resnet50_planet.txt'] optimize_threshold(files, resolution=500, labels=validation.labels)