
Commit 4ca1625: train on the whole dataset
Junhong Xu committed Jun 28, 2017 (1 parent: 9cbabc0)
Showing 2 changed files with 35 additions and 34 deletions.
trainers/baseline_trainer.py (59 changes: 30 additions & 29 deletions)
@@ -15,44 +15,44 @@

"""
A baseline trainer trains the models as followed:
1. ResNet: 18, 34, 50, and 152 (from scratch)
2. DenseNet: 169, 161, and 121 (from scratch)
1. ResNet: 18, 34, 50, and 152 (from pre-trained on 37479)
2. DenseNet: 169, 161, and 121 (from pre-trained on 37479)
-------parameters---------
epochs: 80
batch size: 128, 128, 128, 64, 64, 64, 64
batch size: 128, 128, 64, 64, 64, 64, 50
use SGD+0.9momentum w/o nestrov
weight decay: 5e-4
learning rate: 00-10 epoch: 0.1
10-25 epoch: 0.01
25-35 epoch: 0.005
35-40 epoch: 0.001
40-80 epoch: 0.0001
learning rate: 00-10 epoch: 0.01
10-25 epoch: 0.005
25-35 epoch: 0.001
35-40 epoch: 0.0001
transformations: Rotate, VerticalFlip, HorizontalFlip, RandomCrop
train set: 40479
"""


models = [
# resnet18_planet, resnet34_planet,
# resnet50_planet,
# densenet121, densenet169, densenet161,
resnet152_planet
resnet18_planet, resnet34_planet,
resnet50_planet, densenet121,
densenet169, densenet161,
resnet152_planet
]
batch_size = [# 64, 128,
# 72, 64,
# 64, 64
50
]
batch_size = [
128, 128,
128, 64,
64, 64,
50
]


def get_dataloader(batch_size):
train_data = KgForestDataset(
split='train-37479',
split='train-40479',
transform=Compose(
[
Lambda(lambda x: randomShiftScaleRotate(x, u=0.75, shift_limit=6, scale_limit=6, rotate_limit=45)),
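Only the training split above gets geometric augmentation; the validation pipeline in the next hunk keeps just tensor conversion and normalization. A minimal stand-alone sketch of that asymmetry, assuming torchvision's Compose/Lambda/Normalize in place of the repo's own helpers, with a made-up flip and made-up channel stats:

import numpy as np
import torch
from torchvision.transforms import Compose, Lambda, Normalize

def random_hflip(img):
    # Stand-in for the repo's randomFlip: horizontal flip with p=0.5.
    if np.random.rand() < 0.5:
        img = np.ascontiguousarray(img[:, ::-1, :])
    return img

to_tensor = Lambda(lambda x: torch.from_numpy(x.transpose(2, 0, 1)).float())
mean, std = [0.31, 0.34, 0.30], [0.14, 0.13, 0.12]   # hypothetical stats

train_transform = Compose([Lambda(random_hflip),      # train split only
                           to_tensor,
                           Normalize(mean=mean, std=std)])
val_transform = Compose([to_tensor,                   # validation: no augmentation
                         Normalize(mean=mean, std=std)])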
@@ -68,12 +68,12 @@ def get_dataloader(batch_size):
     train_data_loader = DataLoader(batch_size=batch_size, dataset=train_data, shuffle=True)

     validation = KgForestDataset(
-        split='validation-3000',
+        split='valid-8000',
         transform=Compose(
             [
                 # Lambda(lambda x: randomShiftScaleRotate(x, u=0.75, shift_limit=6, scale_limit=6, rotate_limit=45)),
                 # Lambda(lambda x: randomFlip(x)),
-                # Lambda(lambda x: randomTranspose(x)),
+                # Lambda(lambda x: randomTranspose(x)),
                 Lambda(lambda x: toTensor(x)),
                 Normalize(mean=mean, std=std)
             ]
@@ -107,6 +107,11 @@ def get_optimizer(net, lr, resnet=False, pretrained=False):
     return optimizer


+def load_net(net, name):
+    state_dict = torch.load('../models/{}.pth'.format(name))
+    net.load_state_dict(state_dict)
+
+
 def train_baselines():

     train_data, val_data = get_dataloader(96)
@@ -117,27 +122,23 @@ def train_baselines():
         print(' epoch iter rate | smooth_loss | train_loss (acc) | valid_loss (acc) | total_train_loss\n')
         logger = Logger('../log/{}'.format(name), name)

+        # load pre-trained model on train-37479
         net = model(pretrained=True)
-        optimizer = get_optimizer(net, lr=.01, pretrained=True, resnet=True if 'resnet' in name else False)
+        load_net(net, name)
+        optimizer = get_optimizer(net, lr=.001, pretrained=True, resnet=True if 'resnet' in name else False)
         net = nn.DataParallel(net.cuda())

         train_data.batch_size = batch
         val_data.batch_size = batch

-        num_epoches = 50 #100
+        num_epoches = 40
         print_every_iter = 20
         epoch_test = 1

-        # optimizer
-        # optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
-        # optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-4)

         smooth_loss = 0.0
         train_loss = np.nan
         train_acc = np.nan
-        # test_loss = np.nan
         best_test_loss = np.inf
-        # test_acc = np.nan
         t = time.time()

         for epoch in range(num_epoches):  # loop over the dataset multiple times
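The substantive change above: instead of training from scratch, each net starts from its checkpoint saved after the train-37479 run, then fine-tunes on the full 40479-image set at a tenth of the old learning rate (0.001 vs 0.01). A self-contained sketch of that load-then-fine-tune pattern; the tiny stand-in model, the temp directory, and the model_dir parameter are illustrative, not from the repo:

import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for a *_planet model (17 labels in the Planet dataset).
net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 17))

# Pretend this is the checkpoint left by the train-37479 run.
model_dir = tempfile.mkdtemp()
torch.save(net.state_dict(), os.path.join(model_dir, 'resnet18_planet.pth'))

def load_net(net, name, model_dir):
    # Same logic as the committed helper, with '../models' made a parameter.
    state_dict = torch.load(os.path.join(model_dir, '{}.pth'.format(name)))
    net.load_state_dict(state_dict)

load_net(net, 'resnet18_planet', model_dir)

# Fine-tune with the commit's optimizer settings: SGD + 0.9 momentum,
# weight decay 5e-4, base lr lowered to 0.001 for the warm start.
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9,
                            weight_decay=5e-4)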
util.py (10 changes: 5 additions & 5 deletions)
@@ -104,16 +104,16 @@ def get_learning_rate(optimizer):
     return lr


-def lr_schedule(epoch, optimizer, pretrained=False):
+def lr_schedule(epoch, optimizer, base_lr=0.1, pretrained=False):
     if pretrained:
         if 0 <= epoch < 10:
-            lr = 1e-2
+            lr = base_lr
         elif 10 <= epoch < 25:
-            lr = 5e-3
+            lr = base_lr * 0.5
         elif 25 <= epoch < 40:
-            lr = 1e-3
+            lr = base_lr * 0.1
         else:
-            lr = 1e-4
+            lr = base_lr * 0.01
     else:
         if 0 <= epoch < 10:
             lr = 1e-1
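The schedule is now parameterized by a caller-supplied base_lr; passing base_lr=0.01 with pretrained=True matches the values in the new docstring (0.01 / 0.005 / 0.001 / 0.0001). A runnable sketch of the pretrained branch with an assumed application step, since the lines that write lr into the optimizer are collapsed above; the function name and stand-in model are illustrative:

import torch
import torch.nn as nn

def lr_schedule_pretrained(epoch, optimizer, base_lr=0.01):
    # Pretrained branch as committed; the from-scratch branch is omitted here.
    if 0 <= epoch < 10:
        lr = base_lr
    elif 10 <= epoch < 25:
        lr = base_lr * 0.5
    elif 25 <= epoch < 40:
        lr = base_lr * 0.1
    else:
        lr = base_lr * 0.01
    for param_group in optimizer.param_groups:  # assumed; the standard PyTorch idiom
        param_group['lr'] = lr
    return lr

net = nn.Linear(4, 2)  # stand-in model
opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
for epoch in (0, 10, 25, 40):
    print(epoch, lr_schedule_pretrained(epoch, opt))  # 0.01, 0.005, 0.001, 0.0001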
