Skip to content

Commit

Permalink
train baseline models
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu7 committed Jun 15, 2017
1 parent 04e69ab commit 706d5c7
Show file tree
Hide file tree
Showing 7 changed files with 18,009 additions and 7 deletions.
3,000 changes: 3,000 additions & 0 deletions probs/densenet121.txt

Large diffs are not rendered by default.

3,000 changes: 3,000 additions & 0 deletions probs/densenet161.txt

Large diffs are not rendered by default.

3,000 changes: 3,000 additions & 0 deletions probs/densenet169.txt

Large diffs are not rendered by default.

3,000 changes: 3,000 additions & 0 deletions probs/resnet18_planet.txt

Large diffs are not rendered by default.

3,000 changes: 3,000 additions & 0 deletions probs/resnet34_planet.txt

Large diffs are not rendered by default.

3,000 changes: 3,000 additions & 0 deletions probs/resnet50_planet.txt

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@ def save_results(models, dataloader):
"""Given model/models, this function saves the result of F.sigmoid(model(x))"""
for model in models:
name = str(model).split()[1]

# create
model = model()
model = nn.DataParallel(model.cuda())
# net = model()
# net = nn.DataParallel(net.cuda())# nn.DataParallel(densenet169())
# net.load_state_dict(torch.load('models/%s.pth' % name)['state_dic'
net = torch.load('models/%s.pth' % name)
net.eval()
# model = nn.DataParallel(model.cuda())

# load
model.load_state_dict(torch.load('models/{}.pth'.format(name)))

# forward
result = []
for i, (image, index) in enumerate(dataloader):
for i, (image, target, index) in enumerate(dataloader):
image = Variable(image.cuda(), volatile=True)
# N * 17
probs = F.sigmoid(model(image))
probs = F.sigmoid(net(image))
result.append(probs.data.cpu().numpy())

# concatenate the probabilities
Expand Down Expand Up @@ -239,4 +241,4 @@ def save_time(self, start_time, end_time):
)
dataloader = DataLoader(validation)
save_results([resnet18_planet, resnet34_planet, resnet50_planet,
densenet121, densenet169, densenet161,], dataloader)
densenet121, densenet169, densenet161,], dataloader)

0 comments on commit 706d5c7

Please sign in to comment.