-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_challenge.py
60 lines (51 loc) · 1.86 KB
/
predict_challenge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
'''
EECS 445 - Introduction to Machine Learning
Fall 2018 - Project 2
Predict Challenge
Runs challenge-model inference on the test dataset and saves the predicted
labels to disk as <uniqname>.csv (one label per line, no header).
Usage: python predict_challenge.py --uniqname=<uniqname>
'''
import argparse
import torch
import numpy as np
import pandas as pd
import utils
from dataset import get_train_val_test_loaders
from model.challenge import ResidualBlock, Challenge
# from model.challenge import Challenge
from train_common import *
from utils import config
def predict_challenge(data_loader, model):
    """
    Run model inference over every batch of data_loader and collect predictions.

    The model is switched to evaluation mode and autograd is disabled for the
    duration of the pass; the model's previous train/eval mode is restored
    afterwards, even if inference raises.

    Args:
        data_loader: iterable yielding (X, y) batches; y is ignored.
        model: torch module to evaluate; each batch output is converted to
            labels by predictions() (from train_common).

    Returns:
        1-D numpy array of predicted labels for all batches, in loader order.
        As before, the array is float-typed, and it is empty for an empty loader.
    """
    was_training = model.training
    # Eval mode makes dropout / batch-norm layers (if the model has any) use
    # inference behaviour instead of training-time per-batch behaviour.
    model.eval()
    # Seed with an empty float array so the resulting dtype, and the result for
    # an empty loader, match the previous incremental-concatenate behaviour.
    chunks = [np.array([])]
    try:
        # Inference only: no need to build the autograd graph.
        with torch.no_grad():
            for X, _ in data_loader:
                output = model(X)
                predicted = predictions(output)
                # .cpu() is a no-op for CPU tensors and required for GPU ones.
                chunks.append(predicted.cpu().numpy())
    finally:
        model.train(was_training)
    # Concatenate once at the end rather than per batch (avoids O(n^2) copying).
    return np.concatenate(chunks)
def main(uniqname):
    """
    Restore the challenge model from its checkpoint, predict on the test split,
    and write the predicted semantic labels to '<uniqname>.csv'.

    Args:
        uniqname: string used as the base name of the output CSV file.
    """
    # Only the test loader and the label-name mapping are needed for inference.
    _, _, te_loader, get_semantic_label = get_train_val_test_loaders(
        num_classes=config('challenge.num_classes'))

    # NOTE(review): presumably must match the architecture the checkpoint was
    # saved from — confirm against the training script.
    model = Challenge(ResidualBlock, [2, 2, 2, 2])

    # Attempts to restore the latest checkpoint if one exists
    model, _, _ = restore_checkpoint(model, config('challenge.checkpoint'))
    # Inference: put dropout / batch-norm layers (if any) into eval mode.
    model.eval()

    # Evaluate model
    model_pred = predict_challenge(te_loader, model)

    print('saving challenge predictions...\n')
    model_pred = [get_semantic_label(p) for p in model_pred]
    pd_writer = pd.DataFrame(model_pred, columns=['predictions'])
    # One label per line; no index column and no header row.
    pd_writer.to_csv(uniqname + '.csv', index=False, header=False)
if __name__ == '__main__':
    # Command-line entry point: the only (required) option is the uniqname
    # used to name the output predictions file.
    cli = argparse.ArgumentParser()
    cli.add_argument('--uniqname', required=True)
    main(cli.parse_args().uniqname)