Skip to content

Commit

Permalink
try majority voting
Browse files Browse the repository at this point in the history
  • Loading branch information
Junhong Xu committed Jun 26, 2017
1 parent 5a2dd12 commit 14ba0a2
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 26 deletions.
98 changes: 74 additions & 24 deletions baseline_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,36 @@ def verticalFlip(imgs):

mean = [0.31151703, 0.34061992, 0.29885209]
std = [0.16730586, 0.14391145, 0.13747531]
# threshold = [0.23166666666666666, 0.19599999999999998, 0.18533333333333335,
# 0.08033333333333334, 0.20199999999999999, 0.16866666666666666,
# 0.20533333333333334, 0.27366666666666667, 0.2193333333333333,
# 0.21299999999999999, 0.15666666666666665, 0.096666666666666679,
# 0.21933333333333335, 0.058666666666666673, 0.19033333333333333,
# 0.25866666666666666, 0.057999999999999996] # resnet-152

# threshold = [ 0.18533333, 0.18866667, 0.13533333, 0.03633333, 0.221, 0.17666667,
# 0.231, 0.23933333, 0.21966667, 0.169, 0.23333333, 0.21833333,
# 0.24033333, 0.112, 0.40233333, 0.31833333, 0.237] # densenet-161
thresholds = [
[], # resnet-18
[], # resnet-34
[], # resnet-50
[0.23166666666666666, 0.19599999999999998, 0.18533333333333335,
0.08033333333333334, 0.20199999999999999, 0.16866666666666666,
0.20533333333333334, 0.27366666666666667, 0.2193333333333333,
0.21299999999999999, 0.15666666666666665, 0.096666666666666679,
0.21933333333333335, 0.058666666666666673, 0.19033333333333333,
0.25866666666666666, 0.057999999999999996], # resnet-152
[], # densenet-121
[0.18533333, 0.18866667, 0.13533333,
0.03633333, 0.221, 0.17666667,
0.231, 0.23933333, 0.21966667,
0.169, 0.23333333, 0.21833333,
0.24033333, 0.112, 0.40233333,
0.31833333, 0.237], # densenet-161
[], # densenet-169
]
threshold = [0.23166666666666666, 0.19599999999999998, 0.18533333333333335,
0.08033333333333334, 0.20199999999999999, 0.16866666666666666,
0.20533333333333334, 0.27366666666666667, 0.2193333333333333,
0.21299999999999999, 0.15666666666666665, 0.096666666666666679,
0.21933333333333335, 0.058666666666666673, 0.19033333333333333,
0.25866666666666666, 0.057999999999999996] # resnet-152

threshold = [ 0.18533333, 0.18866667, 0.13533333, 0.03633333, 0.221, 0.17666667,
0.231, 0.23933333, 0.21966667, 0.169, 0.23333333, 0.21833333,
0.24033333, 0.112, 0.40233333, 0.31833333, 0.237] # densenet-161

# threshold = [ 0.17733333, 0.213, 0.15766667, 0.049, 0.28733333, 0.18066667,
# 0.19666667, 0.212, 0.21566667, 0.17233333, 0.16466667, 0.274,
Expand All @@ -74,9 +94,10 @@ def verticalFlip(imgs):

transforms = [default, rotate90, rotate180, rotate270, verticalFlip, horizontalFlip]

models = [# resnet18_planet, resnet34_planet, resnet50_planet, densenet121, densenet161, densenet169
# resnet152_planet
# densenet121,
models = [ resnet18_planet,
resnet34_planet,
resnet50_planet,
densenet121,
densenet169,
densenet161,
resnet152_planet
Expand Down Expand Up @@ -134,8 +155,6 @@ def find_best_threshold(labels, probabilities):
best_thresh = r
best_score = score
t[i] = best_thresh
print('Transform index {}, score {}, threshold {}, label {}'.format(t_idx, best_score, best_thresh, i))
print('Transform index {}, threshold {}, score {}'.format(t_idx, t, best_score))
threshold = threshold + t
acc += best_score
print('AVG ACC,', acc/len(transforms))
Expand Down Expand Up @@ -185,7 +204,7 @@ def do_thresholding(names, labels):
return t


def get_files(excludes=['resnet18']):
def get_files(excludes=None):
file_names = glob.glob('probs/*.txt')
names = []
for filename in file_names:
Expand All @@ -194,7 +213,34 @@ def get_files(excludes=['resnet18']):
return names


def predict_test(t):
def predict_test_majority():
"""
Majority voting method.
"""
labels = np.empty((len(models), 61191, 17))
for m_idx, model in models:
name = str(model).split()[1]
net = nn.DataParallel(model().cuda())
net.load_state_dict(torch.load('models/{}.pth').format(name))
net.eval()
preds = np.zeros(61191, 17)
for t in transforms:
test_dataloader.dataset.images = t(test_dataloader.dataset.images)
pred = predict(net, dataloader=test_dataloader)
preds = preds + pred
# get predictions for the single model
preds = preds/len(transforms)
# get labels
preds = (preds > thresholds[m_idx]).astype(int)
labels[m_idx] = preds

# majority voting
labels = labels.sum(axis=0)
labels = (labels >= len(models)//2).astype(int)
pred_csv(predictions=labels, name='majority_voting_ensembles')


def predict_test_averaging(t):
preds = np.zeros((61191, 17))
# imgs = test_dataloader.dataset.images.copy()
# iterate over models
Expand All @@ -216,16 +262,20 @@ def predict_test(t):


if __name__ == '__main__':
# valid_dataloader = get_validation_loader()
test_dataloader = get_test_dataloader()
valid_dataloader = get_validation_loader()
# test_dataloader = get_test_dataloader()

# save results to files
# probabilities = probs(valid_dataloader)

# get threshold
# file_names = get_files(['resnet18', 'resnet50', 'resnet34'])
# t = do_thresholding(file_names, valid_dataloader.dataset.labels)
# print(t)

# testing
predict_test(threshold)
model_names = ['resnet18', 'resnet34', 'resnet50', 'resnet151', 'densenet121', 'densenet161', 'densenet169']
for m in models:
name = str(m).split()[1]
file_names = get_files([n for n in model_names if n != name])
print('Model {}'.format(name))
t = do_thresholding(file_names, valid_dataloader.dataset.labels)
print(t)

# average testing
# predict_test_averaging(threshold)
5 changes: 3 additions & 2 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@ def predict(net, dataloader):
return probs


def pred_csv(predictions, threshold, name):
def pred_csv(predictions, name, threshold=None):
"""
predictions: numpy array of predicted probabilities
"""
csv_name = os.path.join(KAGGLE_DATA_DIR, 'sample_submission.csv')
submission = pd.read_csv(csv_name)
print(submission)
for i, pred in enumerate(predictions):
labels = (pred > threshold).astype(int)
if threshold is not None:
labels = (pred > threshold).astype(int)
labels = np.where(labels == 1)[0]
labels = ' '.join(idx_name()[index] for index in labels)
submission['tags'][i] = labels
Expand Down

0 comments on commit 14ba0a2

Please sign in to comment.