Offline inference for the McGill people
yannbouteiller committed Nov 27, 2021
1 parent 19ee206 commit 012ef6a
Showing 3 changed files with 168 additions and 12 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -14,10 +14,15 @@ pip install -e .
- unzip the `datasets.zip` file and paste its content under `Portiloop>portiloop_software>dataset`
- unzip the `experiments.zip` file and paste its content under `Portiloop>portiloop_software>experiments`

### Offline inference / simulation:
### Simulation:
The `simulate_Portiloop_1_input_classification.ipynb` [notebook](https://github.com/nicolasvalenchon/Portiloop/blob/main/notebooks/simulate_Portiloop_1_input_classification.ipynb) simulates the Portiloop system and performs inference.
This notebook can be executed with `jupyter notebook`.

### Offline inference:
You can easily use our trained artificial neural network on your own data to detect sleep spindles (note that the data must be collected in the same experimental setting as MODA for this to work; see [our paper](https://arxiv.org/abs/2107.13473)).

Simply write your signal in a plain text file, following the format of `example_data_not_annotated.txt`. Your file can then be used directly for inference in our `offline_inference` notebook; a minimal sketch follows.
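For example (the file name `my_signal.txt` and the experiments path below are placeholders; adapt them to your own file and to wherever you pasted the content of `experiments.zip`):

```python
import numpy as np
from portiloop_software import run_offline_unlabelled, get_final_model_config_dict

# one-column plain text file, same format as example_data_not_annotated.txt:
signal = np.loadtxt("my_signal.txt")

config_dict = get_final_model_config_dict()  # configuration of the final pre-trained model
output, true_idx = run_offline_unlabelled(config_dict,
                                          path_experiments="portiloop_software/experiments",
                                          unlabelled_segment=signal)
# output[i] is the raw network output for the window ending at sample true_idx[i]
```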

### Training:
Functions used for training are defined in Python under the `Software` folder.
We provide [example bash scripts](https://github.com/nicolasvalenchon/Portiloop/releases/download/v0.0.1/scripts.zip) for `SLURM` to train the model on HPC systems.
1 change: 1 addition & 0 deletions portiloop_software/__init__.py
@@ -0,0 +1 @@
from portiloop_software.portiloop_python.ANN.portiloop_detector_training import run_offline_unlabelled, get_final_model_config_dict
172 changes: 161 additions & 11 deletions portiloop_software/portiloop_python/ANN/portiloop_detector_training.py
@@ -78,10 +78,8 @@ def __getitem__(self, idx):
        idx = self.indices[idx]
        assert self.data[3][idx + self.window_size - 1] >= 0, f"Bad index: {idx}."

        signal_seq = self.full_signal[idx - (self.past_signal_len - self.idx_stride):idx + self.window_size].unfold(0, self.window_size,
                                                                                                                    self.idx_stride)
        envelope_seq = self.full_envelope[idx - (self.past_signal_len - self.idx_stride):idx + self.window_size].unfold(0, self.window_size,
                                                                                                                        self.idx_stride)
        signal_seq = self.full_signal[idx - (self.past_signal_len - self.idx_stride):idx + self.window_size].unfold(0, self.window_size, self.idx_stride)
        envelope_seq = self.full_envelope[idx - (self.past_signal_len - self.idx_stride):idx + self.window_size].unfold(0, self.window_size, self.idx_stride)

        ratio_pf = torch.tensor(self.data[2][idx + self.window_size - 1], dtype=torch.float)
        label = torch.tensor(self.data[3][idx + self.window_size - 1], dtype=torch.float)
@@ -94,6 +92,34 @@ def is_spindle(self, idx):
        return True if (self.data[3][idx + self.window_size - 1] > THRESHOLD) else False


class UnlabelledSignalDatasetSingleSegment(Dataset):
    """
    Caution: this dataset does not sample sequences, but single windows
    """
    def __init__(self, unlabelled_segment, window_size):
        self.window_size = window_size
        self.full_signal = torch.tensor(unlabelled_segment, dtype=torch.float).squeeze()
        assert len(self.full_signal.shape) == 1, f"Segment has more than one dimension: {self.full_signal.shape}"
        assert self.window_size <= len(self.full_signal), "Segment smaller than window size."
        self.seq_len = 1  # single window per sample (no sequence dimension)
        self.idx_stride = 1
        self.past_signal_len = self.seq_len * self.idx_stride

        # list of indices that can be sampled:
        self.indices = [idx for idx in range(len(self.full_signal) - self.window_size)  # all possible idxs in the dataset
                        if (not idx < self.past_signal_len)]  # far enough from the beginning to build a sequence up to here

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        assert 0 <= idx < len(self), f"Index out of range ({idx}/{len(self)})."
        idx = self.indices[idx]
        signal_seq = self.full_signal[idx:idx + self.window_size].unfold(0, self.window_size, 1)
        true_idx = idx + self.window_size
        return signal_seq, true_idx
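
# Illustrative usage sketch (not part of the original commit; the dummy signal
# and the window size are arbitrary):
#     ds = UnlabelledSignalDatasetSingleSegment(np.zeros(1000), window_size=54)
#     window, true_idx = ds[0]
#     # window.shape == (1, 54); true_idx == 55, since the first valid start
#     # index is past_signal_len == 1 and true_idx = idx + window_size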


def get_class_idxs(dataset, distribution_mode):
    """
    Directly outputs idx_true and idx_false arrays
@@ -174,7 +200,6 @@ def __len__(self):

class ValidationSampler(Sampler):
    """
    __iter__ stops after an arbitrary number of iterations = batch_size_list * nb_batch
    network_stride (int >= 1, default: 1): divides the size of the dataset (and of the batch) by striding further than 1
    """

@@ -195,6 +220,7 @@ def __iter__(self):
            for i in range(self.nb_segment):
                for j in range(0, (self.seq_stride // self.network_stride) * self.network_stride, self.network_stride):
                    cur_idx = i * self.len_segment + j + cursor_batch * self.seq_stride
                    # print(f"i:{i}, j:{j}, self.len_segment:{self.len_segment}, cursor_batch:{cursor_batch}, self.seq_stride:{self.seq_stride}, cur_idx:{cur_idx}")
                    yield cur_idx
            cursor_batch += 1

@@ -568,6 +594,30 @@ def run_inference(dataloader, criterion, net, device, hidden_size, nb_rnn_layers
    fn = (batch_labels_total * (1 - output_total))
    return output_total, batch_labels_total, loss, acc, tp, tn, fp, fn

def run_inference_unlabelled_offline(dataloader, net, device, hidden_size, nb_rnn_layers, classification, batch_size_validation):
    net_copy = copy.deepcopy(net)
    net_copy = net_copy.to(device)
    net_copy = net_copy.eval()
    true_idx_total = torch.tensor([], device=device)
    output_total = torch.tensor([], device=device)
    h1 = torch.zeros((nb_rnn_layers, batch_size_validation, hidden_size), device=device)
    h2 = torch.zeros((nb_rnn_layers, batch_size_validation, hidden_size), device=device)
    max_value = np.inf
    with torch.no_grad():
        for batch_data in dataloader:
            batch_samples_input1, batch_true_idx = batch_data
            batch_samples_input1 = batch_samples_input1.to(device=device).float()
            output, h1, h2, max_value = net_copy(batch_samples_input1, None, None, h1, h2, max_value)
            output = output.view(-1)
            # if not classification:
            #     output = output  # (output > THRESHOLD)
            # else:
            #     output = (output >= 0.5)
            true_idx_total = torch.cat([true_idx_total, batch_true_idx])
            output_total = torch.cat([output_total, output])
    output_total = output_total.float()
    true_idx_total = true_idx_total.int()
    return output_total, true_idx_total
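
# Design note (illustrative, not part of the original commit): h1 and h2 are
# initialized once and carried across batches, so each of the
# batch_size_validation elements behaves as one continuous RNN stream over the
# segment; the ValidationSampler orders indices so that each successive batch
# advances every stream by seq_stride samples.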

def get_metrics(tp, fp, fn):
    tp_sum = tp.sum().to(torch.float32).item()
@@ -777,7 +827,7 @@ def generate_dataloader(window_size, fe, seq_len, seq_stride, distribution_mode,
    if TEST_SET:
        logging.debug(f"Subjects in test : {test_subject[:, 0]}")

    len_segment_s = LEN_SEGMENT * fe
    len_segment = LEN_SEGMENT * fe
    train_loader = None
    validation_loader = None
    test_loader = None
@@ -796,7 +846,7 @@ def generate_dataloader(window_size, fe, seq_len, seq_stride, distribution_mode,
                                 seq_len=seq_len,
                                 seq_stride=seq_stride,
                                 list_subject=train_subject,
                                 len_segment=len_segment_s)
                                 len_segment=len_segment)

        ds_validation = SignalDataset(filename=filename,
                                      path=path_dataset,
@@ -805,7 +855,7 @@ def generate_dataloader(window_size, fe, seq_len, seq_stride, distribution_mode,
                                      seq_len=1,
                                      seq_stride=1,  # just to be sure, fixed value
                                      list_subject=validation_subject,
                                      len_segment=len_segment_s)
                                      len_segment=len_segment)
        idx_true, idx_false = get_class_idxs(ds_train, distribution_mode)
        samp_train = RandomSampler(idx_true=idx_true,
                                   idx_false=idx_false,
@@ -815,7 +865,7 @@ def generate_dataloader(window_size, fe, seq_len, seq_stride, distribution_mode,

        samp_validation = ValidationSampler(ds_validation,
                                            seq_stride=seq_stride,
                                            len_segment=len_segment_s,
                                            len_segment=len_segment,
                                            nb_segment=nb_segment_validation,
                                            network_stride=network_stride)
        train_loader = DataLoader(ds_train,
@@ -842,11 +892,11 @@ def generate_dataloader(window_size, fe, seq_len, seq_stride, distribution_mode,
                                seq_len=1,
                                seq_stride=1,  # just to be sure, fixed value
                                list_subject=test_subject,
                                len_segment=len_segment_s)
                                len_segment=len_segment)

        samp_test = ValidationSampler(ds_test,
                                      seq_stride=seq_stride,
                                      len_segment=len_segment_s,
                                      len_segment=len_segment,
                                      nb_segment=nb_segment_test,
                                      network_stride=network_stride)

@@ -859,6 +909,31 @@ def generate_dataloader(window_size, fe, seq_len, seq_stride, distribution_mode,

    return train_loader, validation_loader, batch_size_validation, test_loader, batch_size_test, test_subject

def generate_dataloader_unlabelled_offline(unlabelled_segment,
                                           window_size,
                                           seq_stride,
                                           network_stride):
    nb_segment_test = 1
    batch_size_test = len(list(range(0, (seq_stride // network_stride) * network_stride, network_stride))) * nb_segment_test
    unlabelled_segment = torch.tensor(unlabelled_segment, dtype=torch.float).squeeze()
    assert len(unlabelled_segment.shape) == 1, f"Segment has more than one dimension: {unlabelled_segment.shape}"
    len_segment = len(unlabelled_segment)
    ds_test = UnlabelledSignalDatasetSingleSegment(unlabelled_segment=unlabelled_segment,
                                                   window_size=window_size)
    samp_test = ValidationSampler(ds_test,
                                  seq_stride=seq_stride,
                                  len_segment=len_segment - window_size,  # because we don't have additional data at the end of the signal here
                                  nb_segment=nb_segment_test,
                                  network_stride=network_stride)

    test_loader = DataLoader(ds_test,
                             batch_size=batch_size_test,
                             sampler=samp_test,
                             num_workers=0,
                             pin_memory=True,
                             shuffle=False)

    return test_loader, batch_size_test
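
# Note (illustrative, not part of the original commit): with the final model
# configuration (seq_stride = int(0.170 * 250) = 42 samples, network_stride = 1),
# batch_size_test = len(range(0, 42, 1)) * 1 = 42, i.e. one parallel stream per
# phase offset within one stride.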

def run(config_dict, wandb_project, save_model, unique_name):
    global precision_validation_factor
@@ -1129,6 +1204,52 @@ def run(config_dict, wandb_project, save_model, unique_name):
logging.debug("Logger deleted")
return best_model_loss_validation, best_model_f1_score_validation, best_epoch_early_stopping

def run_offline_unlabelled(config_dict, path_experiments, unlabelled_segment):
    logging.debug(f"config_dict: {config_dict}")
    experiment_name = config_dict['experiment_name']
    window_size_s = config_dict["window_size_s"]
    fe = config_dict["fe"]
    seq_stride_s = config_dict["seq_stride_s"]
    hidden_size = config_dict["hidden_size"]
    device_inference = config_dict["device_inference"]
    nb_rnn_layers = config_dict["nb_rnn_layers"]
    classification = config_dict["classification"]
    validation_network_stride = config_dict["validation_network_stride"]

    window_size = int(window_size_s * fe)
    seq_stride = int(seq_stride_s * fe)

    if device_inference.startswith("cuda"):
        assert torch.cuda.is_available(), "CUDA unavailable"

    torch.seed()
    net = PortiloopNetwork(config_dict).to(device=device_inference)

    file_exp = experiment_name
    file_exp += "" if classification else "_on_loss"
    path_experiments = Path(path_experiments)
    if not device_inference.startswith("cuda"):
        checkpoint = torch.load(path_experiments / file_exp, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(path_experiments / file_exp)
    logging.debug("Use checkpoint model")
    net.load_state_dict(checkpoint['model_state_dict'])

    test_loader, batch_size_test = generate_dataloader_unlabelled_offline(unlabelled_segment=unlabelled_segment,
                                                                          window_size=window_size,
                                                                          seq_stride=seq_stride,
                                                                          network_stride=validation_network_stride)

    output_total, true_idx_total = run_inference_unlabelled_offline(dataloader=test_loader,
                                                                    net=net,
                                                                    device=device_inference,
                                                                    hidden_size=hidden_size,
                                                                    nb_rnn_layers=nb_rnn_layers,
                                                                    classification=classification,
                                                                    batch_size_validation=batch_size_test)
    return output_total, true_idx_total



def get_config_dict(index, split_i):
    # config_dict = {'experiment_name': f'pareto_search_10_619_{index}', 'device_train': 'cuda:0', 'device_val': 'cuda:0', 'nb_epoch_max': 1000,
@@ -1209,6 +1330,35 @@ def get_config_dict(index, split_i):
              'estimator_size_memory': 188006400}
    return c_dict

def get_final_model_config_dict(index=0, split_i=0):
    """
    Configuration dictionary of the final 1-input pre-trained model presented in the Portiloop paper.

    Args:
        index: last number in the name of the pre-trained model (several are provided)
        split_i: index of the random train/validation/test split (you can ignore this for inference)

    Returns:
        configuration dictionary of the pre-trained model
    """
    c_dict = {'experiment_name': f'pareto_search_15_35_v4_{index}', 'device_train': 'cpu', 'device_val': 'cpu',
              'device_inference': 'cpu', 'nb_epoch_max': 150, 'max_duration': 257400,
              'nb_epoch_early_stopping_stop': 100, 'early_stopping_smoothing_factor': 0.1, 'fe': 250,
              'nb_batch_per_epoch': 1000,
              'first_layer_dropout': False,
              'power_features_input': False, 'dropout': 0.5, 'adam_w': 0.01, 'distribution_mode': 0,
              'classification': True,
              'reg_balancing': 'none',
              'split_idx': split_i, 'validation_network_stride': 1, 'nb_conv_layers': 3, 'seq_len': 50,
              'nb_channel': 31, 'hidden_size': 7,
              'seq_stride_s': 0.170,
              'nb_rnn_layers': 1, 'RNN': True, 'envelope_input': False, 'lr_adam': 0.0005, 'batch_size': 256,
              'window_size_s': 0.218,
              'stride_pool': 1,
              'stride_conv': 1, 'kernel_conv': 7, 'kernel_pool': 7, 'dilation_conv': 1, 'dilation_pool': 1,
              'nb_out': 18, 'time_in_past': 8.5,
              'estimator_size_memory': 188006400}
    return c_dict
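
# Worked example (illustrative, not part of the original commit): with this
# configuration, run_offline_unlabelled derives window_size = int(0.218 * 250) = 54
# samples and seq_stride = int(0.170 * 250) = 42 samples between successive
# windows of the same stream.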

if __name__ == "__main__":
    parser = ArgumentParser()
