Skip to content

Commit

Permalink
densenet161 0.9290
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu7 committed Jun 19, 2017
1 parent 5dcca27 commit 758af8c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions baseline_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ def probs(dataloader):
n_models = len(models)
n_imgs = dataloader.dataset.num
imgs = dataloader.dataset.images.copy()
probabilities = np.empty(n_transforms, n_models, n_imgs, 17)
probabilities = np.empty((n_transforms, n_models, n_imgs, 17))
for t_idx, transform in enumerate(transforms):
t_name = str(transform).split()[1]
dataloader.dataset.images = transform(imgs)
for m_idx, model in enumerate(models):
name = str(model).split()[1]
net = model().cuda()
net = nn.DataParallel(net)
net = net.load_state_dict(torch.load('models/{}.pth'.format(name)))

net.load_state_dict(torch.load('models/{}.pth'.format(name)))
net.eval()
# predict
m_predictions = predict(net, dataloader)

Expand All @@ -121,5 +121,5 @@ def probs(dataloader):
height=256,
width=256
)
valid_dataloader = DataLoader(validation, batch_size=512, shuffle=False)
valid_dataloader = DataLoader(validation, batch_size=256, shuffle=False)
print(probs(valid_dataloader))
4 changes: 2 additions & 2 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def idx_name():

def predict(net, dataloader):
num = dataloader.dataset.num
probs = np.empty(num, 17)
probs = np.empty((num, 17))
current = 0
for batch_idx, (images, im_ids) in enumerate(dataloader):
for batch_idx, (images, im_ids, _) in enumerate(dataloader):
num = images.size(0)
previous = current
current = previous + num
Expand Down

0 comments on commit 758af8c

Please sign in to comment.