Skip to content

Commit

Permalink
train densenet with dropout rate 0.2.
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu7 committed Jun 4, 2017
1 parent d10fd53 commit 60793df
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 26 deletions.
3 changes: 2 additions & 1 deletion planet_models/densenet_planet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


def densenet121(num_classes=17, pretrained=False):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes)
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), drop_rate=0.2,
num_classes=num_classes)
if pretrained:
# load model dictionary
model_dict = model.state_dict()
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from util import BEST_THRESHOLD


MODEL='models/densenet121.pth'
MODEL='models/pretrained_densenet121.pth'


def test(model_dir, transform):
Expand Down
40 changes: 20 additions & 20 deletions trainers/optimize_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def get_labels(prediction, t):
large_net = nn.DataParallel(densenet121().cuda())
large_net.load_state_dict(torch.load('../models/pretrained_densenet121.pth'))
large_net.eval()
simple_v2 = nn.DataParallel(SimpleNetV3().cuda())
simple_v2.load_state_dict(torch.load('../models/simplenet_v3.1.pth'))
simple_v2.eval()
# simple_v2 = nn.DataParallel(SimpleNetV3().cuda())
# simple_v2.load_state_dict(torch.load('../models/simplenet_v3.1.pth'))
# simple_v2.eval()

resnet_data = validation_jpg_loader(512, transform=Compose([
Scale(224),
Expand All @@ -34,30 +34,30 @@ def get_labels(prediction, t):
Normalize(mean, std)
]))

simplenet_data = validation_jpg_loader(
512, transform=Compose(
[
Scale(72),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean, std)
]
)
)
# simplenet_data = validation_jpg_loader(
# 512, transform=Compose(
# [
# Scale(72),
# RandomHorizontalFlip(),
# ToTensor(),
# Normalize(mean, std)
# ]
# )
# )
pred = []
targets = []
# predict
for batch_index, ((resnet_img, resnet_target), (simplenet_img, simplenet_target)) \
in enumerate(zip(resnet_data, simplenet_data)):
#for batch_index, (simplenet_img, simplenet_target) in enumerate(simplenet_data):
resnet_output = evaluate(large_net, resnet_img)
# for batch_index, ((resnet_img, resnet_target), (simplenet_img, simplenet_target)) \
# in enumerate(zip(resnet_data, simplenet_data)):
for batch_index, (simplenet_img, simplenet_target) in enumerate(resnet_data):
output = evaluate(large_net, simplenet_img)
# resnet_output = F.sigmoid(resnet_output)

simplenet_output = evaluate(simple_v2, simplenet_img)
# simplenet_output = evaluate(simple_v2, simplenet_img)
# simplenet_output = F.sigmoid(simplenet_output)

output = F.sigmoid((simplenet_output + resnet_output)/2)
# output = F.sigmoid(output)
# output = F.sigmoid((simplenet_output + resnet_output)/2)
output = F.sigmoid(output)
pred.append(output.data.cpu().numpy())
targets.append(simplenet_target.cpu().numpy())

Expand Down
6 changes: 3 additions & 3 deletions trainers/train_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __call__(self, img):
return img


def get_optimizer(model, lr=1e-4, weight_decay=1e-4):
def get_optimizer(model, lr=5e-5, weight_decay=5e-4):
params = [
{'params': model.features.parameters(), 'lr': lr},
{'params': model.classifier.parameters(), 'lr': lr * 10}
Expand All @@ -32,7 +32,7 @@ def train(epoch):
net.cuda()
net = torch.nn.DataParallel(net, device_ids=[0, 1])
# resnet.load_state_dict(torch.load('../models/simplenet_v3.pth'))
train_data_set = train_jpg_loader(100, transform=Compose(
train_data_set = train_jpg_loader(90, transform=Compose(
[

Scale(256),
Expand All @@ -45,7 +45,7 @@ def train(epoch):
))
validation_data_set = validation_jpg_loader(64, transform=Compose(
[
Scale(256),
Scale(224),
ToTensor(),
Normalize(mean, std)
]
Expand Down
2 changes: 1 addition & 1 deletion util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from planet_models.resnet_planet import resnet14_planet
from planet_models.simplenet_v2 import SimpleNetV2

BEST_THRESHOLD= [0.16, 0.093, 0.203, 0.203, 0.241, 0.175, 0.119, 0.225, 0.134, 0.074, 0.141, 0.22, 0.073, 0.184, 0.167, 0.049, 0.071]
BEST_THRESHOLD= [0.166, 0.082, 0.237, 0.178, 0.242, 0.223, 0.168, 0.173, 0.197, 0.107, 0.14, 0.242, 0.103, 0.17, 0.273, 0.147, 0.09 ]


def evaluate(model, image):
Expand Down

0 comments on commit 60793df

Please sign in to comment.