-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathCV_REmodel.py
75 lines (58 loc) · 2.55 KB
/
CV_REmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from data.data_loader import load_data
from model.reading_embedding_model import ReadingEmbeddingModel
from utils.cross_validation import k_fold_CV
from metrics.metrics import total_loss, relevance_accuracy
from utils.plotting import plot_tsne, plot_cm, plot_roc_auc
# ---------------------------------------------------------------------------
# Train the model on real data.
#
# load_data returns a tuple whose first four entries are:
#   - embeddings    : [num_sentences, max_num_words, hidden_size]
#   - eeg_features  : [num_sentences, max_num_words, eeg_feature_size = 5460]
#   - gaze_features : [num_sentences, max_num_words, num_measure_mode = 12]
#   - labels        : [num_sentences, max_num_words]
# followed by sen_len, the per-sentence word counts.
# ---------------------------------------------------------------------------

# Load the data of the i-th subject: 0 - 8 for 9 different subjects.
idx = 0
# Feature-selection mask forwarded to k_fold_CV; presumably one flag per
# feature group (embeddings / EEG / gaze) — TODO confirm in utils.cross_validation.
feat_choice = [0, 1, 0]
epochs = 500
n_translayer = 1  # forwarded to k_fold_CV as n_trans
opt_type = 'SGD'  # or 'Adam'
min_lr = -1  # not to use lr scheduler when -1
# Set to True to print label counts before/after downsampling (debug aid).
print_data_stats = False

downsampled_data = load_data(downsample=True, subIdx=idx)  # Down Sample or not
original_data = load_data(downsample=False, subIdx=idx)
embeddings, eeg_features, gaze_features, labels, sen_len = original_data


def _count_zero_labels(label_rows, lengths):
    """Return the total number of 0-labelled words across all sentences.

    Computed per sentence as ``length - sum(row)``; this assumes binary
    (0/1) labels and that padding positions beyond ``length`` add nothing
    to the row sum — TODO confirm against data.data_loader.
    """
    return sum(length - sum(row) for row, length in zip(label_rows, lengths))


if print_data_stats:
    d_labels, d_sen_len = downsampled_data[-2], downsampled_data[-1]
    # total number of 1s after downsample
    print('Total # of 1s after downsample = ', d_labels.sum())
    # total number of 0s before downsample
    print('Total # of 0s before downsample = ', _count_zero_labels(labels, sen_len))
    # total number of 0s after downsample; the old commented-out version read
    # original_data here by mistake and just repeated the "before" count.
    # Assumes downsampled labels keep the same padded layout — TODO confirm.
    print('Total # of 0s after downsample = ', _count_zero_labels(d_labels, d_sen_len))
# ---------------------------------------------------------------------------
# Device setup and seeding
# ---------------------------------------------------------------------------
# NOTE: seeding happens after the load_data calls earlier in this script, so
# any randomness inside load_data is not governed by this seed.
seed = 42
torch.manual_seed(seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Also seed the CUDA generators explicitly when a GPU is present. These calls
# must be inside the `if`; unindented they raise an IndentationError.
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
print(device)
# ---------------------------------------------------------------------------
# Cross-validation and reporting
# ---------------------------------------------------------------------------
# 5-fold CV on the non-downsampled data. downsample=True is passed through to
# k_fold_CV, which presumably downsamples within each fold — TODO confirm in
# utils.cross_validation.
cv_results = k_fold_CV(
    k_fold=5,
    batch_size=40,
    epochs=epochs,
    data=original_data,
    downsample=True,
    n_trans=n_translayer,
    feat_choice=feat_choice,
    metrics=[total_loss, relevance_accuracy],
    opt_type=opt_type,
    lr=0.05,
    min_lr=min_lr,
    MultiHeaded=True,
    num_heads=3,
    device=device,
)
# The result unpacks into four summary values plus a (preds, labels) pair.
avg_acc, test_acc_list, overall_cm, overall_f1, pooled_outputs = cv_results
overall_preds, overall_labels = pooled_outputs

print('Average Acc = ', avg_acc)
print('Overall F1 Score = ', overall_f1)

# ROC-AUC curve over the pooled predictions.
plot_roc_auc(overall_labels, overall_preds)
# Confusion matrix over the pooled results.
# NOTE(review): given the `test_acc_list` name, these may be held-out fold
# results rather than training data — confirm the plot title.
plot_cm(overall_cm, 'Confusion Matrix on Training Data')