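"""Evaluation script for SPIRA convolutional models (SpiraConvV1/V2).

Loads a trained checkpoint, runs it over a test CSV using the configuration
saved with the checkpoint, and reports the test loss, binary accuracy, and a
confusion matrix.
"""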
import os
import math
import torch
import torch.nn as nn
import traceback
import pandas as pd
import time
import numpy as np
import argparse
from utils.generic_utils import load_config, save_config_file
from utils.generic_utils import set_init_dict
from utils.generic_utils import NoamLR, binary_acc
from utils.generic_utils import save_best_checkpoint
from utils.tensorboard import TensorboardWriter
from utils.dataset import test_dataloader
from models.spiraconv import SpiraConvV1, SpiraConvV2
from utils.audio_processor import AudioProcessor
import random
# set random seed
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)


def test(criterion, ap, model, c, testloader, step, cuda, confusion_matrix=False):
    padding_with_max_lenght = c.dataset['padding_with_max_lenght']
    losses = []
    accs = []
    model.zero_grad()
    model.eval()
    loss = 0
    acc = 0
    preds = []
    targets = []
    with torch.no_grad():
        for feature, target, slices, targets_org in testloader:
            if cuda:
                feature = feature.cuda()
                target = target.cuda()
            output = model(feature).float()
            # output = torch.round(output * 10**4) / (10**4)
            # Calculate loss
            if not padding_with_max_lenght and not c.dataset['split_wav_using_overlapping']:
                target = target[:, :output.shape[1], :target.shape[2]]
            if c.dataset['split_wav_using_overlapping']:
                # unpack the overlapping slices, averaging predictions and targets
                # per utterance before computing loss and accuracy
                if slices is not None and targets_org is not None:
                    idx = 0
                    new_output = []
                    new_target = []
                    for i in range(slices.size(0)):
                        num_samples = int(slices[i].cpu().numpy())
                        samples_output = output[idx:idx + num_samples]
                        output_mean = samples_output.mean()
                        samples_target = target[idx:idx + num_samples]
                        target_mean = samples_target.mean()
                        new_target.append(target_mean)
                        new_output.append(output_mean)
                        idx += num_samples
                    target = torch.stack(new_target, dim=0)
                    output = torch.stack(new_output, dim=0)
                    if cuda:
                        output = output.cuda()
                        target = target.cuda()
                        targets_org = targets_org.cuda()
                    if not torch.equal(targets_org, target):
                        raise RuntimeError("Integrity problem while unpacking the overlapping slices for loss and accuracy computation. Check the dataloader!")
            loss += criterion(output, target).item()
            # calculate binary accuracy
            y_pred_tag = torch.round(output)
            acc += (y_pred_tag == target).float().sum().item()
            preds += y_pred_tag.reshape(-1).int().cpu().numpy().tolist()
            targets += target.reshape(-1).int().cpu().numpy().tolist()
    if confusion_matrix:
        print("======== Confusion Matrix ==========")
        y_target = pd.Series(targets, name='Target')
        y_pred = pd.Series(preds, name='Predicted')
        df_confusion = pd.crosstab(y_target, y_pred, rownames=['Target'], colnames=['Predicted'], margins=True)
        print(df_confusion)
    mean_acc = acc / len(testloader.dataset)
    mean_loss = loss / len(testloader.dataset)
    print("Test\n Loss:", mean_loss, "Accuracy:", mean_acc)
    return mean_acc


def run_test(args, checkpoint_path, testloader, c, model_name, ap, cuda=True):
    # define loss function
    criterion = nn.BCELoss(reduction='sum')
    padding_with_max_lenght = c.dataset['padding_with_max_lenght']
    if model_name == 'spiraconv_v1':
        model = SpiraConvV1(c)
    elif model_name == 'spiraconv_v2':
        model = SpiraConvV2(c)
    else:
        raise Exception("The model '" + model_name + "' is not supported")

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=c.train_config['learning_rate'])
    else:
        raise Exception("The optimizer '%s' is not supported" % c.train_config['optimizer'])

    step = 0
    if checkpoint_path is not None:
        print("Loading checkpoint: %s" % checkpoint_path)
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            print("Model loaded successfully!")
        except Exception as e:
            raise ValueError("You need to pass a valid checkpoint; check that your config.json matches the one used to train this checkpoint. Original error: " + str(e))
        step = checkpoint['step']
    else:
        raise ValueError("You need to pass a checkpoint_path")

    # move the model to the GPU
    if cuda:
        model = model.cuda()

    model.train(False)
    test_acc = test(criterion, ap, model, c, testloader, step, cuda=cuda, confusion_matrix=True)
    return test_acc


if __name__ == '__main__':
    # Example usage:
    # python test.py --test_csv ../SPIRA_Dataset_V1/metadata_test.csv -r ../SPIRA_Dataset_V1/ --checkpoint_path ../checkpoints/spiraconv-trained-with-SPIRA_Dataset_V1/spiraconv/checkpoint_1068.pt --config_path ../checkpoints/spiraconv-trained-with-SPIRA_Dataset_V1/spiraconv/config.json --batch_size 5 --num_workers 2 --no_insert_noise True
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--test_csv', type=str, required=True,
                        help="Test csv, example: ../SPIRA_Dataset_V1/metadata_test.csv")
    parser.add_argument('-r', '--test_root_dir', type=str, required=True,
                        help="Test root dir, example: ../SPIRA_Dataset_V1/")
    parser.add_argument('-c', '--config_path', type=str, required=True,
                        help="JSON configuration file saved alongside the checkpoint")
    parser.add_argument('--checkpoint_path', type=str, default=None, required=True,
                        help="Path of the checkpoint .pt file to evaluate")
    parser.add_argument('--batch_size', type=int, default=20,
                        help="Batch size for the test")
    parser.add_argument('--num_workers', type=int, default=10,
                        help="Number of workers for the test data loader")
    # note: with type=bool, argparse treats any non-empty string (e.g. "True") as True
    parser.add_argument('--no_insert_noise', type=bool, default=False,
                        help="Do not insert noise in the test set")
    parser.add_argument('--num_noise_control', type=int, default=1,
                        help="Number of noise files to insert in control samples")
    parser.add_argument('--num_noise_patient', type=int, default=0,
                        help="Number of noise files to insert in patient samples")
    args = parser.parse_args()

    c = load_config(args.config_path)
    ap = AudioProcessor(**c.audio)

    if not args.no_insert_noise:
        c.data_aumentation['insert_noise'] = True
    else:
        c.data_aumentation['insert_noise'] = False

    # set values for noise insertion in the test set
    c.data_aumentation["num_noise_control"] = args.num_noise_control
    c.data_aumentation["num_noise_patient"] = args.num_noise_patient
    print("Insert noise?", c.data_aumentation['insert_noise'])

    c.dataset['test_csv'] = args.test_csv
    c.dataset['test_data_root_path'] = args.test_root_dir
    c.test_config['batch_size'] = args.batch_size
    c.test_config['num_workers'] = args.num_workers

    max_seq_len = c.dataset['max_seq_len']
    # avoid shadowing the imported test_dataloader function
    testloader = test_dataloader(c, ap, max_seq_len=max_seq_len)
    run_test(args, args.checkpoint_path, testloader, c, c.model_name, ap, cuda=True)