Skip to content

Commit

Permalink
pre-trained densenet121
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu7 committed Jun 3, 2017
1 parent 12d2561 commit d10fd53
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 47 deletions.
18 changes: 18 additions & 0 deletions planet_models/densenet_planet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch.utils.model_zoo as model_zoo
from torchvision.models.densenet import model_urls
from torchvision.models import DenseNet


def densenet121(num_classes=17, pretrained=False):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes)
if pretrained:
# load model dictionary
model_dict = model.state_dict()
# load pretrained model
pretrained_dict = model_zoo.load_url(model_urls['densenet121'])
# update model dictionary using pretrained model without classifier layer
model_dict.update({key: pretrained_dict[key] for key in pretrained_dict.keys() if 'classifier' not in key})
model.load_state_dict(model_dict)
return model


4 changes: 2 additions & 2 deletions test_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


SIMPLENET = 'models/simplenet_v3.1.pth'
RESNET = 'models/densenet121.pth'
RESNET = 'models/pretrained_densenet121.pth'

def test():
resnet = nn.DataParallel(densenet121().cuda())
Expand All @@ -22,7 +22,7 @@ def test():
simple_v2.load_state_dict(torch.load(SIMPLENET))
simple_v2.eval()

name = 'ensembles_simple_v3.1_densenet121'
name = 'ensembles_simple_v3.1_pretrained_densenet121'
resnet_loader = test_jpg_loader(512, transform=Compose(
[
Scale(224),
Expand Down
45 changes: 22 additions & 23 deletions trainers/optimize_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def get_labels(prediction, t):
return pred

large_net = nn.DataParallel(densenet121().cuda())
large_net.load_state_dict(torch.load('../models/densenet121.pth'))
large_net.load_state_dict(torch.load('../models/pretrained_densenet121.pth'))
large_net.eval()
# simple_v2 = nn.DataParallel(SimpleNetV3().cuda())
# simple_v2.load_state_dict(torch.load('../models/simplenet_v3.1.pth'))
# simple_v2.eval()
simple_v2 = nn.DataParallel(SimpleNetV3().cuda())
simple_v2.load_state_dict(torch.load('../models/simplenet_v3.1.pth'))
simple_v2.eval()

resnet_data = validation_jpg_loader(512, transform=Compose([
Scale(224),
Expand All @@ -34,39 +34,38 @@ def get_labels(prediction, t):
Normalize(mean, std)
]))

# simplenet_data = validation_jpg_loader(
# 512, transform=Compose(
# [
# Scale(72),
# RandomHorizontalFlip(),
# ToTensor(),
# Normalize(mean, std)
# ]
# )
# )
num_class = 17
simplenet_data = validation_jpg_loader(
512, transform=Compose(
[
Scale(72),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean, std)
]
)
)
pred = []
targets = []
# predict
# for batch_index, ((resnet_img, resnet_target), (simplenet_img, simplenet_target)) \
# in enumerate(zip(resnet_data, simplenet_data)):
for batch_index, (simplenet_img, simplenet_target) in enumerate(resnet_data):
resnet_output = evaluate(large_net, simplenet_img)
for batch_index, ((resnet_img, resnet_target), (simplenet_img, simplenet_target)) \
in enumerate(zip(resnet_data, simplenet_data)):
#for batch_index, (simplenet_img, simplenet_target) in enumerate(simplenet_data):
resnet_output = evaluate(large_net, resnet_img)
# resnet_output = F.sigmoid(resnet_output)

# simplenet_output = evaluate(simple_v2, simplenet_img)
simplenet_output = evaluate(simple_v2, simplenet_img)
# simplenet_output = F.sigmoid(simplenet_output)

# output = F.sigmoid((simplenet_output + resnet_output)/2)
output = F.sigmoid(resnet_output)
output = F.sigmoid((simplenet_output + resnet_output)/2)
# output = F.sigmoid(output)
pred.append(output.data.cpu().numpy())
targets.append(simplenet_target.cpu().numpy())

pred = np.vstack(pred)
targets = np.vstack(targets)
threshold = [0.2] * 17
# optimize
for i in range(num_class):
for i in range(17):
best_thresh = 0.0
best_score = 0.0
for r in range(resolution):
Expand Down
31 changes: 10 additions & 21 deletions trainers/train_densenet.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from torchvision.models.densenet import DenseNet
from datasets import train_jpg_loader, validation_jpg_loader
import torch
from torch.nn import functional as F
from torch.nn import *
from util import *
from torch import optim
from torchvision.transforms import *
import torch.utils.model_zoo as model_zoo
from torchvision.models.densenet import model_urls
from planet_models.densenet_planet import densenet121

NAME = 'densenet121'

NAME = 'pretrained_densenet121'


class RandomVerticalFLip(object):
Expand All @@ -19,27 +15,20 @@ def __call__(self, img):
return img


def densenet121(num_classes=17, pretrained=False):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes)
if pretrained:
state_dict = model_zoo.load_url(model_urls[NAME])

model.load_state_dict(model_zoo.load_url(model_urls[NAME]))
return model


def lr_scheduler(optimizer, epoch):
if epoch % 10 == 0 and epoch != 0:
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * 0.1
def get_optimizer(model, lr=1e-4, weight_decay=1e-4):
params = [
{'params': model.features.parameters(), 'lr': lr},
{'params': model.classifier.parameters(), 'lr': lr * 10}
]
return optim.Adam(params=params, weight_decay=weight_decay)


def train(epoch):
criterion = MultiLabelSoftMarginLoss()
net = densenet121()
logger = Logger('../log/', NAME)
# optimizer = optim.Adam(lr=5e-4, params=net.parameters())
optimizer = optim.Adam(lr=1e-4, params=net.parameters(), weight_decay=5e-5)
optimizer = get_optimizer(net)
net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])
# resnet.load_state_dict(torch.load('../models/simplenet_v3.pth'))
Expand Down
2 changes: 1 addition & 1 deletion util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from planet_models.resnet_planet import resnet14_planet
from planet_models.simplenet_v2 import SimpleNetV2

BEST_THRESHOLD= [0.233, 0.065, 0.196, 0.315, 0.226, 0.202, 0.108, 0.185, 0.285, 0.14, 0.292, 0.238, 0.194, 0.35, 0.196, 0.145, 0.369]
BEST_THRESHOLD= [0.16, 0.093, 0.203, 0.203, 0.241, 0.175, 0.119, 0.225, 0.134, 0.074, 0.141, 0.22, 0.073, 0.184, 0.167, 0.049, 0.071]


def evaluate(model, image):
Expand Down

0 comments on commit d10fd53

Please sign in to comment.