Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

variables renaming for consistency in prep.py + tests #197

Merged
merged 3 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int, frequen
elif type(frequencies) == dict:
values = compute_freq_bands(data, sampling_rate, frequencies)
else:
TypeError("Please use a list or a dictionary to specify frequencies.")
raise TypeError("Please use a list or a dictionary to specify frequencies.")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exceptions must be raised in order to interrupt the execution flow


# compute connectivity values
result = compute_sync(values, mode, epochs_average)
Expand Down Expand Up @@ -511,7 +511,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
con = con_num / con_den

else:
ValueError('Metric type not supported.')
raise ValueError('Metric type not supported.')

con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch
if epochs_average:
Expand Down
192 changes: 99 additions & 93 deletions hypyp/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import mne
from autoreject import get_rejection_threshold, AutoReject, RejectLog
from mne.preprocessing import ICA, corrmap
from typing import List, Tuple, TypedDict
from typing import List, Tuple, TypedDict, Union

class DicAR(TypedDict):
"""
Expand All @@ -31,7 +31,10 @@ class DicAR(TypedDict):
dyad: float


def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]:
def filt(
raw_S: List[mne.io.Raw],
freqs: Tuple[Union[float, None], Union[float, None]] = (2., None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the "freqs" argument to complete the TODO in this function

) -> List[mne.io.Raw]:
"""
Filters list of raw data to remove slow drifts.

Expand All @@ -42,11 +45,8 @@ def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]:
Returns:
raws: list of high-pass filtered raws.
"""
# TODO: l_freq and h_freq as param
raws: List[mne.io.Raw] = []
for raw in raw_S:
raws.append(mne.io.Raw.filter(raw, l_freq=2., h_freq=None))


raws = [mne.io.Raw.filter(raw, l_freq=freqs[0], h_freq=freqs[1]) for raw in raw_S]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use list comprehension in multiple places to take advantage of type inference

return raws


Expand Down Expand Up @@ -77,68 +77,71 @@ def ICA_choice_comp(icas: List[ICA], epochs: List[mne.Epochs]) -> List[mne.Epoch

# choosing participant and its component as a template for the other participant
# if do not want to apply ICA on the data, do not fill the answer
subj_numb = input("Which participant ICA do you want"
" to use as a template for artifact rejection?"
" Index begins at zero. (If you do not want to apply"
" ICA on your data, do not enter nothing and press enter.)")
comp_number = input("Which IC do you want to use as a template?"
subject_id = input("Which participant ICA do you want"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variable renaming for consistency

" to use as a template for artifact rejection?"
" Index begins at zero. (If you do not want to apply"
" ICA on your data, do not enter nothing and press enter.)")

component_id = input("Which IC do you want to use as a template?"
" Index begins at zero. (If you did not choose"
"a participant number at first question,"
"then do not enter nothing and press enter again"
"to not apply ICA on your data)")

# applying ICA
if (len(subj_numb) != 0 and len(comp_number) != 0):
cleaned_epochs_ICA = ICA_apply(icas,
int(subj_numb),
int(comp_number),
epochs)
else:
cleaned_epochs_ICA = epochs
" a participant number at first question,"
" then do not enter nothing and press enter again"
" to not apply ICA on your data)")

return cleaned_epochs_ICA
if (len(subject_id) == 0 or len(component_id) == 0):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Invert code flow to simplify code (return early to reduce nested code)

return epochs

return ICA_apply(icas, int(subject_id), int(component_id), epochs)


def ICA_apply(icas: List[ICA], subj_number: int, comp_number: int, epochs: List[mne.Epochs]) -> List[mne.Epochs]:
def ICA_apply(icas: List[ICA], subject_id: int, component_id: int, epochs: List[mne.Epochs], plot: bool = True) -> List[mne.Epochs]:
"""
Applies ICA with template model from 1 participant in the dyad.
See ICA_choice_comp for a detailed description of the parameters and output.
"""

cleaned_epochs_ICA: List[ICA] = []

# selecting which ICs corresponding to the template
template_eog_component = icas[subj_number].get_components()[:, comp_number]
template_eog_component = icas[subject_id].get_components()[:, component_id]

# applying corrmap with at least 1 component detected for each subj
fig_template, fig_detected = corrmap(icas,
template=template_eog_component,
threshold=0.9,
label='blink',
ch_type='eeg')
corrmap(icas,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added plot as ICA_apply argument (to avoid plots launching GUI windows during tests). The presence/absence of plot argument here change the return type of corrmap, so we cannot always unpack fig_template and fig_detected. Anyway, they are not used.

template=template_eog_component,
threshold=0.9,
label='blink',
ch_type='eeg',
plot=plot,
)

# labeling the ICs that capture blink artifacts
print([ica.labels_ for ica in icas])

# selecting ICA components after viz
for i in icas:
i.exclude = i.labels_['blink']
for ica in icas:
ica.exclude = ica.labels_['blink']

epoch_all_ch = []
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary array, we don't use it afterwards. We just use a variable in the loop scope below for the epochs of a subject, then add it to the returned list

# applying ica on clean_epochs
# for each participant
for i, j in zip(range(0, len(epochs)), icas):
for subject_id, ica in zip(range(0, len(epochs)), icas):
# taking all channels to apply ICA
bads = epochs[i].info['bads']
epoch_all_ch.append(mne.Epochs.copy(epochs[i]))
epoch_all_ch[i].info['bads'] = []
j.apply(epoch_all_ch[i])
epoch_all_ch[i].info['bads'] = bads
cleaned_epochs_ICA.append(epoch_all_ch[i])
epochs_subj = mne.Epochs.copy(epochs[subject_id])
bads_keep = epochs_subj.info['bads']
epochs_subj.info['bads'] = []
ica.apply(epochs_subj)
epochs_subj.info['bads'] = bads_keep
cleaned_epochs_ICA.append(epochs_subj)

return cleaned_epochs_ICA


def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params: dict, random_state: int) -> List[ICA]:
def ICA_fit(
epochs: List[mne.Epochs],
n_components: int,
method: str,
fit_params: dict,
random_state: int
) -> List[ICA]:
"""
Computes global Autorejection to fit Independent Components Analysis
on Epochs, for each participant.
Expand Down Expand Up @@ -178,23 +181,24 @@ def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params
objects, see MNE documentation for more details).
"""
icas: List[ICA] = []
for epoch in epochs:
for epochs_subj in epochs:
# per subj
# applying AR to find global rejection threshold
reject = get_rejection_threshold(epoch, ch_types='eeg')
reject = get_rejection_threshold(epochs_subj, ch_types='eeg')
# if very long, can change decim value
print('The rejection dictionary is %s' % reject)
print(f"The rejection dictionary is {reject}")

# fitting ICA on filt_raw after AR
ica = ICA(n_components=n_components,
method=method,
fit_params= fit_params,
random_state=random_state).fit(epoch)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be wrong, but I think calling .fit(epoch) here is useless, since we call .fit() below again.

random_state=random_state)

# take bad channels into account in ICA fit
epoch_all_ch = mne.Epochs.copy(epoch)
epoch_all_ch.info['bads'] = []
epoch_all_ch.drop_bad(reject=reject, flat=None)
icas.append(ica.fit(epoch_all_ch))
epochs_fit = mne.Epochs.copy(epochs_subj)
epochs_fit.info['bads'] = []
epochs_fit.drop_bad(reject=reject, flat=None)
icas.append(ica.fit(epochs_fit))

return icas

Expand Down Expand Up @@ -227,7 +231,7 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre
for each subject and for the intersection of the them.
"""

bad_epochs_AR: List[RejectLog] = []
reject_logs: List[RejectLog] = []
AR: List[AutoReject] = []
dic_AR: DicAR = {}
dic_AR['strategy'] = strategy
Expand All @@ -241,89 +245,91 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre
# n_interpolates = np.array([1, 4, 8, 16, 32, 64])
# consensus_percs = np.linspace(0.5, 1.0, 11)

for clean_epochs in cleaned_epochs_ICA: # per subj

for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj
picks = mne.pick_types(
clean_epochs[0].info,
clean_epochs_subj[subject_id].info,
meg=False,
eeg=True,
stim=False,
eog=False,
exclude=[])

ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
thresh_method='random_search', random_state=42,
verbose='tqdm_notebook')
AR.append(ar)
ar = AutoReject(n_interpolates,
consensus_percs,
picks=picks,
thresh_method='random_search',
random_state=42,
verbose='tqdm_notebook')

# fitting AR to get bad epochs
ar.fit(clean_epochs)
reject_log = ar.get_reject_log(clean_epochs, picks=picks)
bad_epochs_AR.append(reject_log)
ar.fit(clean_epochs_subj)
reject_log = ar.get_reject_log(clean_epochs_subj, picks=picks)

# taking bad epochs for min 1 subj (dyad)
log1 = bad_epochs_AR[0]
log2 = bad_epochs_AR[1]
AR.append(ar)
reject_logs.append(reject_log)

bad1 = np.where(log1.bad_epochs == True)
bad2 = np.where(log2.bad_epochs == True)
# taking bad epochs for min 1 subj (dyad)
bad1_idx = np.where(reject_logs[0].bad_epochs == True)[0].tolist()
bad2_idx = np.where(reject_logs[1].bad_epochs == True)[0].tolist()

if strategy == 'union':
bad = list(set(bad1[0].tolist()).union(set(bad2[0].tolist())))
bad_idx = list(set(bad1_idx).union(set(bad2_idx)))
elif strategy == 'intersection':
bad = list(set(bad1[0].tolist()).intersection(set(bad2[0].tolist())))
bad_idx = list(set(bad1_idx).intersection(set(bad2_idx)))
else:
TypeError('not good strategy input!')
raise RuntimeError('not good strategy input!')

# storing the percentage of epochs rejection
dic_AR['S1'] = float((len(bad1[0].tolist())/len(cleaned_epochs_ICA[0]))*100)
dic_AR['S2'] = float((len(bad2[0].tolist())/len(cleaned_epochs_ICA[1]))*100)
dic_AR['S1'] = float((len(bad1_idx) / len(cleaned_epochs_ICA[0])) * 100)
dic_AR['S2'] = float((len(bad2_idx) / len(cleaned_epochs_ICA[1])) * 100)

# picking good epochs for the two subj
cleaned_epochs_AR: List[mne.Epochs] = []
for clean_epochs in cleaned_epochs_ICA: # per subj

for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj
# keep a copy of the original data
clean_epochs_ep = copy.deepcopy(clean_epochs)
clean_epochs_ep = clean_epochs_ep.drop(indices=bad)
epochs_subj = copy.deepcopy(clean_epochs_subj)
epochs_subj.drop(indices=bad_idx)
# interpolating bads or removing epochs
ar = AR[cleaned_epochs_ICA.index(clean_epochs)]
clean_epochs_AR = ar.transform(clean_epochs_ep)
cleaned_epochs_AR.append(clean_epochs_AR)
ar = AR[subject_id]
epochs_AR = ar.transform(epochs_subj)
cleaned_epochs_AR.append(epochs_AR)

if strategy == 'intersection':
# equalizing epochs length between two participants
mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

dic_AR['dyad'] = float(((len(cleaned_epochs_ICA[0])-len(cleaned_epochs_AR[0]))/len(cleaned_epochs_ICA[0]))*100)
n_epochs_ICA = len(cleaned_epochs_ICA[0])
n_epochs_AR = len(cleaned_epochs_AR[0])

dic_AR['dyad'] = float(((n_epochs_ICA - n_epochs_AR) / n_epochs_ICA) * 100)
if dic_AR['dyad'] >= threshold:
TypeError('percentage of rejected epochs above threshold!')
raise RuntimeError(f"percentage of rejected epochs ({dic_AR['dyad']}) above threshold ({threshold})! ")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a bug here: Errors need to be raised to interrupt the execution flow. The test on this condition was inadequate (also fixed in this commit).

Also the type RuntimeError seems more appropriate.

if verbose:
print('%s percent of bad epochs' % dic_AR['dyad'])
print(f"{dic_AR['dyad']} percent of bad epochs")

# Vizualisation before after AR
evoked_before: List[mne.Evoked] = []
for clean_epochs in cleaned_epochs_ICA: # per subj
evoked_before.append(clean_epochs.average())

evoked_after_AR: List[mne.Evoked] = []
for clean in cleaned_epochs_AR:
evoked_after_AR.append(clean.average())
evoked_before: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_ICA]
evoked_after_AR: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_AR]

if verbose:
for i, j in zip(evoked_before, evoked_after_AR):
for evoked_before_subj, evoked_after_AR_subj in zip(evoked_before, evoked_after_AR):
fig, axes = plt.subplots(2, 1, figsize=(6, 6))
for ax in axes:
ax.tick_params(axis='x', which='both', bottom='off', top='off')
ax.tick_params(axis='y', which='both', left='off', right='off')

ylim = dict(grad=(-170, 200))
i.pick_types(eeg=True, exclude=[])
i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)

evoked_before_subj.pick_types(eeg=True, exclude=[])
evoked_before_subj.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
axes[0].set_title('Before autoreject')
j.pick_types(eeg=True, exclude=[])
j.plot(exclude=[], axes=axes[1], ylim=ylim)

evoked_after_AR_subj.pick_types(eeg=True, exclude=[])
evoked_after_AR_subj.plot(exclude=[], axes=axes[1], ylim=ylim)
# Problème titre ne s'affiche pas pour le deuxieme axe !!!
axes[1].set_title('After autoreject')

plt.tight_layout()

return cleaned_epochs_AR, dic_AR
2 changes: 1 addition & 1 deletion hypyp/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,7 +1735,7 @@ def plot_xwt(sig1: mne.Epochs, sig2: mne.Epochs,
plt.imshow(data, aspect='auto', cmap=my_cm, interpolation='lanczos')

else:
ValueError('Analysis must be set as phase, power, or wtc.')
raise ValueError('Analysis must be set as phase, power, or wtc.')

plt.gca().invert_yaxis()
plt.ylabel('Frequencies (Hz)')
Expand Down
Loading
Loading