
Commit

Updated everything to PyTorch 0.4.1 and added a .gitignore.
sercharpak committed Dec 4, 2018
1 parent b9d2171 commit 5a4c149
Showing 13 changed files with 543 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*~
*checkpoint.ipynb
__pycache__
*.pyc
Binary file added EEG-BCI-Project/classErr_baselines_500.png
Binary file added EEG-BCI-Project/classErr_deeper_500.png
25 changes: 13 additions & 12 deletions EEG-BCI-Project/cross_validation.py
@@ -63,9 +63,9 @@ def train_model(model, criterion, optimizer, scheduler,train_input, train_target
"""
# Augment the data if necessary and make Variables
train_input, train_target = augment_train(train_input, train_target, n_augmentation, 0.1, 2 , 0.2,verbose)
-train_input, train_target = Variable(train_input), Variable(train_target)
-if(save_loss):
-val_input,val_target = Variable(val_input),Variable(val_target)
+#train_input, train_target = Variable(train_input), Variable(train_target) # No longer necessary in Pytorch 0.4
+#if(save_loss): # No longer necessary in Pytorch 0.4
+# val_input,val_target = Variable(val_input),Variable(val_target) # No longer necessary in Pytorch 0.4

for e in range(n_epochs):
scheduler.step() # decrease the learning rate
@@ -88,10 +88,10 @@ def train_model(model, criterion, optimizer, scheduler,train_input, train_target
optimizer.step()
# save loss data, if save_loss is true
if(save_loss):
-tr_loss.append(criterion(model(train_input),train_target).data[0])
-val_loss.append(criterion(model(val_input),val_target).data[0])
-tr_err.append(evaluate_error(model,train_input.data,train_target.data))
-val_err.append(evaluate_error(model,val_input.data,val_target.data))
+tr_loss.append(criterion(model(train_input),train_target).item())#item and detach from pytorch 0.4
+val_loss.append(criterion(model(val_input),val_target).item())#item and detach from pytorch 0.4
+tr_err.append(evaluate_error(model,train_input.detach(),train_target.detach()).item()) #item and detach from pytorch 0.4
+val_err.append(evaluate_error(model,val_input.detach(),val_target.detach()).item()) #item and detach from pytorch 0.4
if(verbose == 1):
print("Training ended successfully after {} epochs".format(n_epochs))

@@ -103,13 +103,14 @@ def evaluate_error(model, data_input, data_target):
data_target -> the correct labels for the input data
Outputs: the number of missclassified samples
"""
-data_input, data_target = Variable(data_input,volatile=True),Variable(data_target,volatile=True)
-nb_errors = 0
+#data_input, data_target = Variable(data_input,volatile=True),Variable(data_target,volatile=True)# No longer necessary in Pytorch 0.4
+with torch.no_grad():
+nb_errors = 0

-output = model(data_input)
-nb_errors += (output.max(1)[1] != data_target).long().data.sum() # count the number of samples in the output with labels different than in the target
+output = model(data_input)
+nb_errors += (output.max(1)[1] != data_target).long().detach().sum() # count the number of samples in the output with labels different than in the target #detach is for pytorch 0.4

-return nb_errors/data_input.size(0) # take the mean, to get a fraction of missclassified samples
+return nb_errors.float()/data_input.size(0) # take the mean, to get a fraction of missclassified samples #pytorch 0.4 need to be float to be able to divide.

def cross_validate(model,criterion,optimizer,scheduler,dataset,target,k_fold,n_epochs=250,batch_size = 50,n_augmentation =0,verbose=0):
"""
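For reference, a minimal sketch of the evaluation pattern these hunks migrate to under PyTorch 0.4+: no Variable wrapping, torch.no_grad() in place of volatile=True, .item() in place of .data[0], and a float cast before the division. The function name error_fraction and its arguments are illustrative, not part of the repository.

import torch

def error_fraction(model, data_input, data_target):
    # PyTorch >= 0.4: tensors no longer need Variable wrapping;
    # no_grad() replaces the old volatile=True evaluation mode.
    with torch.no_grad():
        output = model(data_input)
        nb_errors = (output.max(1)[1] != data_target).long().sum()
    # Cast to float before dividing, then .item() to get a plain Python number.
    return (nb_errors.float() / data_input.size(0)).item()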
1 change: 1 addition & 0 deletions EEG-BCI-Project/cv_500t.txt

Large diffs are not rendered by default.

100 changes: 100 additions & 0 deletions EEG-BCI-Project/data_bci/labels_data_set_iv.txt
@@ -0,0 +1,100 @@
1
0
0
0
1
0
0
0
1
1
1
0
0
1
1
0
0
0
0
0
0
0
0
1
0
1
1
1
0
1
0
0
1
1
0
0
1
0
1
1
1
1
0
0
0
1
0
0
1
1
1
1
1
0
1
1
1
1
0
1
1
1
0
1
0
0
1
0
0
1
0
1
1
0
0
0
0
0
1
1
0
1
0
1
1
1
0
1
0
1
1
0
1
0
1
1
0
1
1
0
100 changes: 100 additions & 0 deletions EEG-BCI-Project/data_bci/sp1s_aa_test_1000Hz.txt

Large diffs are not rendered by default.

316 changes: 316 additions & 0 deletions EEG-BCI-Project/data_bci/sp1s_aa_train_1000Hz.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions EEG-BCI-Project/final_500t.txt
@@ -0,0 +1 @@
[["Linear baseline", " train ", 0.0917721539735794], ["Linear baseline", " test ", 0.3499999940395355], ["Convolutional baseline", " train ", 0.2405063360929489], ["Convolutional baseline", " test ", 0.27000001072883606], ["EEG Net (2018)", " train ", 0.06962025165557861], ["EEG Net (2018)", " test ", 0.1899999976158142], ["ShallowConvNet", " train ", 0.09810126572847366], ["ShallowConvNet", " test ", 0.4099999964237213]]
Binary file added EEG-BCI-Project/loss_baselines_500.png
Binary file added EEG-BCI-Project/loss_deeper_500.png
1 change: 1 addition & 0 deletions EEG-BCI-Project/plot_CV_results.py
@@ -60,6 +60,7 @@
plt.xlabel('Epoch')
plt.ylabel('Loss (Cross entropy)')
for i in range(0,len(lines)-2*n_last,2):
+print(lines[i])
model = lines[i][0]
tr_loss = lines[i][2][2:]
val_loss = lines[i+1][2][2:]
9 changes: 7 additions & 2 deletions EEG-BCI-Project/test.py
@@ -22,6 +22,10 @@

"""Load the data and standardize"""
import dlc_bci as bci
+import torch
+
+dtype = torch.float
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Loading the data and standardizing...")
train_input ,train_target = bci.load( root = './data_bci',one_khz=one_khz)
test_input,test_target = bci.load(root='./data_bci', train = False, one_khz=one_khz)
@@ -59,6 +63,7 @@

import cross_validation as cv #Also used for the train_model, even without cross validation_data
import json

dump_final = []
if(cross_val):
dump = []
@@ -84,8 +89,8 @@
final_te_error = cv.evaluate_error(model,test_input,test_target)
print("Train error = {} ; Test error = {} ".format(final_tr_error,final_te_error))

-dump_final.append((model.name(), " train ", final_tr_error))
-dump_final.append((model.name(), " test " , final_te_error))
+dump_final.append((model.name(), " train ", final_tr_error.item()))
+dump_final.append((model.name(), " test " , final_te_error.item()))


file = open('final_'+str(n_time_points)+'t.txt','w+')
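A self-contained sketch of the device/dtype setup added to test.py and of the .item() calls used when dumping the final errors. The tensor shapes and the linear model below are illustrative placeholders, not the project's actual data or networks.

import torch

dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Illustrative stand-ins for the tensors loaded via dlc_bci.load (shapes are made up).
inputs = torch.randn(4, 28 * 50, dtype=dtype, device=device)
targets = torch.randint(0, 2, (4,), device=device)

model = torch.nn.Linear(28 * 50, 2).to(device)
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
print(loss.item())  # .item() turns a 0-dim tensor into a plain Python float, as in the dump_final lines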
