-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
variables renaming for consistency in prep.py + tests #197
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
import mne | ||
from autoreject import get_rejection_threshold, AutoReject, RejectLog | ||
from mne.preprocessing import ICA, corrmap | ||
from typing import List, Tuple, TypedDict | ||
from typing import List, Tuple, TypedDict, Union | ||
|
||
class DicAR(TypedDict): | ||
""" | ||
|
@@ -31,7 +31,10 @@ class DicAR(TypedDict): | |
dyad: float | ||
|
||
|
||
def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]: | ||
def filt( | ||
raw_S: List[mne.io.Raw], | ||
freqs: Tuple[Union[float, None], Union[float, None]] = (2., None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added the "freqs" argument to complete the TODO in this function |
||
) -> List[mne.io.Raw]: | ||
""" | ||
Filters list of raw data to remove slow drifts. | ||
|
||
|
@@ -42,11 +45,8 @@ def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]: | |
Returns: | ||
raws: list of high-pass filtered raws. | ||
""" | ||
# TODO: l_freq and h_freq as param | ||
raws: List[mne.io.Raw] = [] | ||
for raw in raw_S: | ||
raws.append(mne.io.Raw.filter(raw, l_freq=2., h_freq=None)) | ||
|
||
|
||
raws = [mne.io.Raw.filter(raw, l_freq=freqs[0], h_freq=freqs[1]) for raw in raw_S] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use list comprehension in multiple places to take advantage of type inference |
||
return raws | ||
|
||
|
||
|
@@ -77,68 +77,71 @@ def ICA_choice_comp(icas: List[ICA], epochs: List[mne.Epochs]) -> List[mne.Epoch | |
|
||
# choosing participant and its component as a template for the other participant | ||
# if do not want to apply ICA on the data, do not fill the answer | ||
subj_numb = input("Which participant ICA do you want" | ||
" to use as a template for artifact rejection?" | ||
" Index begins at zero. (If you do not want to apply" | ||
" ICA on your data, do not enter nothing and press enter.)") | ||
comp_number = input("Which IC do you want to use as a template?" | ||
subject_id = input("Which participant ICA do you want" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. variable renaming for consistency |
||
" to use as a template for artifact rejection?" | ||
" Index begins at zero. (If you do not want to apply" | ||
" ICA on your data, do not enter nothing and press enter.)") | ||
|
||
component_id = input("Which IC do you want to use as a template?" | ||
" Index begins at zero. (If you did not choose" | ||
"a participant number at first question," | ||
"then do not enter nothing and press enter again" | ||
"to not apply ICA on your data)") | ||
|
||
# applying ICA | ||
if (len(subj_numb) != 0 and len(comp_number) != 0): | ||
cleaned_epochs_ICA = ICA_apply(icas, | ||
int(subj_numb), | ||
int(comp_number), | ||
epochs) | ||
else: | ||
cleaned_epochs_ICA = epochs | ||
" a participant number at first question," | ||
" then do not enter nothing and press enter again" | ||
" to not apply ICA on your data)") | ||
|
||
return cleaned_epochs_ICA | ||
if (len(subject_id) == 0 or len(component_id) == 0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Invert code flow to simplify code (return early to reduce nested code) |
||
return epochs | ||
|
||
return ICA_apply(icas, int(subject_id), int(component_id), epochs) | ||
|
||
|
||
def ICA_apply(icas: List[ICA], subj_number: int, comp_number: int, epochs: List[mne.Epochs]) -> List[mne.Epochs]: | ||
def ICA_apply(icas: List[ICA], subject_id: int, component_id: int, epochs: List[mne.Epochs], plot: bool = True) -> List[mne.Epochs]: | ||
""" | ||
Applies ICA with template model from 1 participant in the dyad. | ||
See ICA_choice_comp for a detailed description of the parameters and output. | ||
""" | ||
|
||
cleaned_epochs_ICA: List[ICA] = [] | ||
|
||
# selecting which ICs corresponding to the template | ||
template_eog_component = icas[subj_number].get_components()[:, comp_number] | ||
template_eog_component = icas[subject_id].get_components()[:, component_id] | ||
|
||
# applying corrmap with at least 1 component detected for each subj | ||
fig_template, fig_detected = corrmap(icas, | ||
template=template_eog_component, | ||
threshold=0.9, | ||
label='blink', | ||
ch_type='eeg') | ||
corrmap(icas, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
template=template_eog_component, | ||
threshold=0.9, | ||
label='blink', | ||
ch_type='eeg', | ||
plot=plot, | ||
) | ||
|
||
# labeling the ICs that capture blink artifacts | ||
print([ica.labels_ for ica in icas]) | ||
|
||
# selecting ICA components after viz | ||
for i in icas: | ||
i.exclude = i.labels_['blink'] | ||
for ica in icas: | ||
ica.exclude = ica.labels_['blink'] | ||
|
||
epoch_all_ch = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unnecessary array, we don't use it afterwards. We just use a variable in the loop scope below for the epochs of a subject, then add it to the returned list |
||
# applying ica on clean_epochs | ||
# for each participant | ||
for i, j in zip(range(0, len(epochs)), icas): | ||
for subject_id, ica in zip(range(0, len(epochs)), icas): | ||
# taking all channels to apply ICA | ||
bads = epochs[i].info['bads'] | ||
epoch_all_ch.append(mne.Epochs.copy(epochs[i])) | ||
epoch_all_ch[i].info['bads'] = [] | ||
j.apply(epoch_all_ch[i]) | ||
epoch_all_ch[i].info['bads'] = bads | ||
cleaned_epochs_ICA.append(epoch_all_ch[i]) | ||
epochs_subj = mne.Epochs.copy(epochs[subject_id]) | ||
bads_keep = epochs_subj.info['bads'] | ||
epochs_subj.info['bads'] = [] | ||
ica.apply(epochs_subj) | ||
epochs_subj.info['bads'] = bads_keep | ||
cleaned_epochs_ICA.append(epochs_subj) | ||
|
||
return cleaned_epochs_ICA | ||
|
||
|
||
def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params: dict, random_state: int) -> List[ICA]: | ||
def ICA_fit( | ||
epochs: List[mne.Epochs], | ||
n_components: int, | ||
method: str, | ||
fit_params: dict, | ||
random_state: int | ||
) -> List[ICA]: | ||
""" | ||
Computes global Autorejection to fit Independent Components Analysis | ||
on Epochs, for each participant. | ||
|
@@ -178,23 +181,24 @@ def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params | |
objects, see MNE documentation for more details). | ||
""" | ||
icas: List[ICA] = [] | ||
for epoch in epochs: | ||
for epochs_subj in epochs: | ||
# per subj | ||
# applying AR to find global rejection threshold | ||
reject = get_rejection_threshold(epoch, ch_types='eeg') | ||
reject = get_rejection_threshold(epochs_subj, ch_types='eeg') | ||
# if very long, can change decim value | ||
print('The rejection dictionary is %s' % reject) | ||
print(f"The rejection dictionary is {reject}") | ||
|
||
# fitting ICA on filt_raw after AR | ||
ica = ICA(n_components=n_components, | ||
method=method, | ||
fit_params= fit_params, | ||
random_state=random_state).fit(epoch) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be wrong, but I think calling |
||
random_state=random_state) | ||
|
||
# take bad channels into account in ICA fit | ||
epoch_all_ch = mne.Epochs.copy(epoch) | ||
epoch_all_ch.info['bads'] = [] | ||
epoch_all_ch.drop_bad(reject=reject, flat=None) | ||
icas.append(ica.fit(epoch_all_ch)) | ||
epochs_fit = mne.Epochs.copy(epochs_subj) | ||
epochs_fit.info['bads'] = [] | ||
epochs_fit.drop_bad(reject=reject, flat=None) | ||
icas.append(ica.fit(epochs_fit)) | ||
|
||
return icas | ||
|
||
|
@@ -227,7 +231,7 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre | |
for each subject and for the intersection of the them. | ||
""" | ||
|
||
bad_epochs_AR: List[RejectLog] = [] | ||
reject_logs: List[RejectLog] = [] | ||
AR: List[AutoReject] = [] | ||
dic_AR: DicAR = {} | ||
dic_AR['strategy'] = strategy | ||
|
@@ -241,89 +245,91 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre | |
# n_interpolates = np.array([1, 4, 8, 16, 32, 64]) | ||
# consensus_percs = np.linspace(0.5, 1.0, 11) | ||
|
||
for clean_epochs in cleaned_epochs_ICA: # per subj | ||
|
||
for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj | ||
picks = mne.pick_types( | ||
clean_epochs[0].info, | ||
clean_epochs_subj[subject_id].info, | ||
meg=False, | ||
eeg=True, | ||
stim=False, | ||
eog=False, | ||
exclude=[]) | ||
|
||
ar = AutoReject(n_interpolates, consensus_percs, picks=picks, | ||
thresh_method='random_search', random_state=42, | ||
verbose='tqdm_notebook') | ||
AR.append(ar) | ||
ar = AutoReject(n_interpolates, | ||
consensus_percs, | ||
picks=picks, | ||
thresh_method='random_search', | ||
random_state=42, | ||
verbose='tqdm_notebook') | ||
|
||
# fitting AR to get bad epochs | ||
ar.fit(clean_epochs) | ||
reject_log = ar.get_reject_log(clean_epochs, picks=picks) | ||
bad_epochs_AR.append(reject_log) | ||
ar.fit(clean_epochs_subj) | ||
reject_log = ar.get_reject_log(clean_epochs_subj, picks=picks) | ||
|
||
# taking bad epochs for min 1 subj (dyad) | ||
log1 = bad_epochs_AR[0] | ||
log2 = bad_epochs_AR[1] | ||
AR.append(ar) | ||
reject_logs.append(reject_log) | ||
|
||
bad1 = np.where(log1.bad_epochs == True) | ||
bad2 = np.where(log2.bad_epochs == True) | ||
# taking bad epochs for min 1 subj (dyad) | ||
bad1_idx = np.where(reject_logs[0].bad_epochs == True)[0].tolist() | ||
bad2_idx = np.where(reject_logs[1].bad_epochs == True)[0].tolist() | ||
|
||
if strategy == 'union': | ||
bad = list(set(bad1[0].tolist()).union(set(bad2[0].tolist()))) | ||
bad_idx = list(set(bad1_idx).union(set(bad2_idx))) | ||
elif strategy == 'intersection': | ||
bad = list(set(bad1[0].tolist()).intersection(set(bad2[0].tolist()))) | ||
bad_idx = list(set(bad1_idx).intersection(set(bad2_idx))) | ||
else: | ||
TypeError('not good strategy input!') | ||
raise RuntimeError('not good strategy input!') | ||
|
||
# storing the percentage of epochs rejection | ||
dic_AR['S1'] = float((len(bad1[0].tolist())/len(cleaned_epochs_ICA[0]))*100) | ||
dic_AR['S2'] = float((len(bad2[0].tolist())/len(cleaned_epochs_ICA[1]))*100) | ||
dic_AR['S1'] = float((len(bad1_idx) / len(cleaned_epochs_ICA[0])) * 100) | ||
dic_AR['S2'] = float((len(bad2_idx) / len(cleaned_epochs_ICA[1])) * 100) | ||
|
||
# picking good epochs for the two subj | ||
cleaned_epochs_AR: List[mne.Epochs] = [] | ||
for clean_epochs in cleaned_epochs_ICA: # per subj | ||
|
||
for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj | ||
# keep a copy of the original data | ||
clean_epochs_ep = copy.deepcopy(clean_epochs) | ||
clean_epochs_ep = clean_epochs_ep.drop(indices=bad) | ||
epochs_subj = copy.deepcopy(clean_epochs_subj) | ||
epochs_subj.drop(indices=bad_idx) | ||
# interpolating bads or removing epochs | ||
ar = AR[cleaned_epochs_ICA.index(clean_epochs)] | ||
clean_epochs_AR = ar.transform(clean_epochs_ep) | ||
cleaned_epochs_AR.append(clean_epochs_AR) | ||
ar = AR[subject_id] | ||
epochs_AR = ar.transform(epochs_subj) | ||
cleaned_epochs_AR.append(epochs_AR) | ||
|
||
if strategy == 'intersection': | ||
# equalizing epochs length between two participants | ||
mne.epochs.equalize_epoch_counts(cleaned_epochs_AR) | ||
|
||
dic_AR['dyad'] = float(((len(cleaned_epochs_ICA[0])-len(cleaned_epochs_AR[0]))/len(cleaned_epochs_ICA[0]))*100) | ||
n_epochs_ICA = len(cleaned_epochs_ICA[0]) | ||
n_epochs_AR = len(cleaned_epochs_AR[0]) | ||
|
||
dic_AR['dyad'] = float(((n_epochs_ICA - n_epochs_AR) / n_epochs_ICA) * 100) | ||
if dic_AR['dyad'] >= threshold: | ||
TypeError('percentage of rejected epochs above threshold!') | ||
raise RuntimeError(f"percentage of rejected epochs ({dic_AR['dyad']}) above threshold ({threshold})! ") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a bug here: Errors need to be raised to interrupt the execution flow. The test on this condition was inadequate (also fixed in this commit). Also the type |
||
if verbose: | ||
print('%s percent of bad epochs' % dic_AR['dyad']) | ||
print(f"{dic_AR['dyad']} percent of bad epochs") | ||
|
||
# Vizualisation before after AR | ||
evoked_before: List[mne.Evoked] = [] | ||
for clean_epochs in cleaned_epochs_ICA: # per subj | ||
evoked_before.append(clean_epochs.average()) | ||
|
||
evoked_after_AR: List[mne.Evoked] = [] | ||
for clean in cleaned_epochs_AR: | ||
evoked_after_AR.append(clean.average()) | ||
evoked_before: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_ICA] | ||
evoked_after_AR: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_AR] | ||
|
||
if verbose: | ||
for i, j in zip(evoked_before, evoked_after_AR): | ||
for evoked_before_subj, evoked_after_AR_subj in zip(evoked_before, evoked_after_AR): | ||
fig, axes = plt.subplots(2, 1, figsize=(6, 6)) | ||
for ax in axes: | ||
ax.tick_params(axis='x', which='both', bottom='off', top='off') | ||
ax.tick_params(axis='y', which='both', left='off', right='off') | ||
|
||
ylim = dict(grad=(-170, 200)) | ||
i.pick_types(eeg=True, exclude=[]) | ||
i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False) | ||
|
||
evoked_before_subj.pick_types(eeg=True, exclude=[]) | ||
evoked_before_subj.plot(exclude=[], axes=axes[0], ylim=ylim, show=False) | ||
axes[0].set_title('Before autoreject') | ||
j.pick_types(eeg=True, exclude=[]) | ||
j.plot(exclude=[], axes=axes[1], ylim=ylim) | ||
|
||
evoked_after_AR_subj.pick_types(eeg=True, exclude=[]) | ||
evoked_after_AR_subj.plot(exclude=[], axes=axes[1], ylim=ylim) | ||
# Problème titre ne s'affiche pas pour le deuxieme axe !!! | ||
axes[1].set_title('After autoreject') | ||
|
||
plt.tight_layout() | ||
|
||
return cleaned_epochs_AR, dic_AR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exceptions must be raised in order to interrupt the execution flow