Skip to content

Commit

Permalink
variables renaming for consistency in prep.py + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
patricefortin committed Sep 30, 2024
1 parent 80e539e commit 52e9cd6
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 135 deletions.
4 changes: 2 additions & 2 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int, frequen
elif type(frequencies) == dict:
values = compute_freq_bands(data, sampling_rate, frequencies)
else:
TypeError("Please use a list or a dictionary to specify frequencies.")
raise TypeError("Please use a list or a dictionary to specify frequencies.")

# compute connectivity values
result = compute_sync(values, mode, epochs_average)
Expand Down Expand Up @@ -511,7 +511,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
con = con_num / con_den

else:
ValueError('Metric type not supported.')
raise ValueError('Metric type not supported.')

con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch
if epochs_average:
Expand Down
190 changes: 97 additions & 93 deletions hypyp/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import mne
from autoreject import get_rejection_threshold, AutoReject, RejectLog
from mne.preprocessing import ICA, corrmap
from typing import List, Tuple, TypedDict
from typing import List, Tuple, TypedDict, Union

class DicAR(TypedDict):
"""
Expand All @@ -31,7 +31,10 @@ class DicAR(TypedDict):
dyad: float


def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]:
def filt(
raw_S: List[mne.io.Raw],
freqs: Tuple[Union[float, None], Union[float, None]] = (2., None)
) -> List[mne.io.Raw]:
"""
Filters list of raw data to remove slow drifts.
Expand All @@ -42,11 +45,7 @@ def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]:
Returns:
raws: list of high-pass filtered raws.
"""
# TODO: l_freq and h_freq as param
raws: List[mne.io.Raw] = []
for raw in raw_S:
raws.append(mne.io.Raw.filter(raw, l_freq=2., h_freq=None))

raws = [mne.io.Raw.filter(raw, l_freq=freqs[0], h_freq=freqs[1]) for raw in raw_S]
return raws


Expand Down Expand Up @@ -77,68 +76,71 @@ def ICA_choice_comp(icas: List[ICA], epochs: List[mne.Epochs]) -> List[mne.Epoch

# choosing participant and its component as a template for the other participant
# if do not want to apply ICA on the data, do not fill the answer
subj_numb = input("Which participant ICA do you want"
" to use as a template for artifact rejection?"
" Index begins at zero. (If you do not want to apply"
" ICA on your data, do not enter nothing and press enter.)")
comp_number = input("Which IC do you want to use as a template?"
subject_id = input("Which participant ICA do you want"
" to use as a template for artifact rejection?"
" Index begins at zero. (If you do not want to apply"
" ICA on your data, do not enter nothing and press enter.)")

component_id = input("Which IC do you want to use as a template?"
" Index begins at zero. (If you did not choose"
"a participant number at first question,"
"then do not enter nothing and press enter again"
"to not apply ICA on your data)")

# applying ICA
if (len(subj_numb) != 0 and len(comp_number) != 0):
cleaned_epochs_ICA = ICA_apply(icas,
int(subj_numb),
int(comp_number),
epochs)
else:
cleaned_epochs_ICA = epochs
" a participant number at first question,"
" then do not enter nothing and press enter again"
" to not apply ICA on your data)")

return cleaned_epochs_ICA
if (len(subject_id) == 0 or len(component_id) == 0):
return epochs

return ICA_apply(icas, int(subject_id), int(component_id), epochs)


def ICA_apply(icas: List[ICA], subj_number: int, comp_number: int, epochs: List[mne.Epochs]) -> List[mne.Epochs]:
def ICA_apply(icas: List[ICA], subject_id: int, component_id: int, epochs: List[mne.Epochs], plot: bool = True) -> List[mne.Epochs]:
"""
Applies ICA with template model from 1 participant in the dyad.
See ICA_choice_comp for a detailed description of the parameters and output.
"""

cleaned_epochs_ICA: List[ICA] = []

# selecting which ICs corresponding to the template
template_eog_component = icas[subj_number].get_components()[:, comp_number]
template_eog_component = icas[subject_id].get_components()[:, component_id]

# applying corrmap with at least 1 component detected for each subj
fig_template, fig_detected = corrmap(icas,
template=template_eog_component,
threshold=0.9,
label='blink',
ch_type='eeg')
corrmap(icas,
template=template_eog_component,
threshold=0.9,
label='blink',
ch_type='eeg',
plot=plot,
)

# labeling the ICs that capture blink artifacts
print([ica.labels_ for ica in icas])

# selecting ICA components after viz
for i in icas:
i.exclude = i.labels_['blink']
for ica in icas:
ica.exclude = ica.labels_['blink']

epoch_all_ch = []
# applying ica on clean_epochs
# for each participant
for i, j in zip(range(0, len(epochs)), icas):
for subject_id, ica in zip(range(0, len(epochs)), icas):
# taking all channels to apply ICA
bads = epochs[i].info['bads']
epoch_all_ch.append(mne.Epochs.copy(epochs[i]))
epoch_all_ch[i].info['bads'] = []
j.apply(epoch_all_ch[i])
epoch_all_ch[i].info['bads'] = bads
cleaned_epochs_ICA.append(epoch_all_ch[i])
epochs_subj = mne.Epochs.copy(epochs[subject_id])
bads_keep = epochs_subj.info['bads']
epochs_subj.info['bads'] = []
ica.apply(epochs_subj)
epochs_subj.info['bads'] = bads_keep
cleaned_epochs_ICA.append(epochs_subj)

return cleaned_epochs_ICA


def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params: dict, random_state: int) -> List[ICA]:
def ICA_fit(
epochs: List[mne.Epochs],
n_components: int,
method: str,
fit_params: dict,
random_state: int
) -> List[ICA]:
"""
Computes global Autorejection to fit Independent Components Analysis
on Epochs, for each participant.
Expand Down Expand Up @@ -178,23 +180,24 @@ def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params
objects, see MNE documentation for more details).
"""
icas: List[ICA] = []
for epoch in epochs:
for epochs_subj in epochs:
# per subj
# applying AR to find global rejection threshold
reject = get_rejection_threshold(epoch, ch_types='eeg')
reject = get_rejection_threshold(epochs_subj, ch_types='eeg')
# if very long, can change decim value
print('The rejection dictionary is %s' % reject)
print(f"The rejection dictionary is {reject}")

# fitting ICA on filt_raw after AR
ica = ICA(n_components=n_components,
method=method,
fit_params= fit_params,
random_state=random_state).fit(epoch)
random_state=random_state)

# take bad channels into account in ICA fit
epoch_all_ch = mne.Epochs.copy(epoch)
epoch_all_ch.info['bads'] = []
epoch_all_ch.drop_bad(reject=reject, flat=None)
icas.append(ica.fit(epoch_all_ch))
epochs_fit = mne.Epochs.copy(epochs_subj)
epochs_fit.info['bads'] = []
epochs_fit.drop_bad(reject=reject, flat=None)
icas.append(ica.fit(epochs_fit))

return icas

Expand Down Expand Up @@ -227,7 +230,7 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre
for each subject and for the intersection of the them.
"""

bad_epochs_AR: List[RejectLog] = []
reject_logs: List[RejectLog] = []
AR: List[AutoReject] = []
dic_AR: DicAR = {}
dic_AR['strategy'] = strategy
Expand All @@ -241,89 +244,90 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre
# n_interpolates = np.array([1, 4, 8, 16, 32, 64])
# consensus_percs = np.linspace(0.5, 1.0, 11)

for clean_epochs in cleaned_epochs_ICA: # per subj

for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj
picks = mne.pick_types(
clean_epochs[0].info,
clean_epochs_subj[subject_id].info,
meg=False,
eeg=True,
stim=False,
eog=False,
exclude=[])

ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
thresh_method='random_search', random_state=42,
verbose='tqdm_notebook')
AR.append(ar)
ar = AutoReject(n_interpolates,
consensus_percs,
picks=picks,
thresh_method='random_search',
random_state=42,
verbose='tqdm_notebook')

# fitting AR to get bad epochs
ar.fit(clean_epochs)
reject_log = ar.get_reject_log(clean_epochs, picks=picks)
bad_epochs_AR.append(reject_log)
ar.fit(clean_epochs_subj)
reject_log = ar.get_reject_log(clean_epochs_subj, picks=picks)

# taking bad epochs for min 1 subj (dyad)
log1 = bad_epochs_AR[0]
log2 = bad_epochs_AR[1]
AR.append(ar)
reject_logs.append(reject_log)

bad1 = np.where(log1.bad_epochs == True)
bad2 = np.where(log2.bad_epochs == True)
# taking bad epochs for min 1 subj (dyad)
bad1_idx = np.where(reject_logs[0].bad_epochs == True)[0].tolist()
bad2_idx = np.where(reject_logs[1].bad_epochs == True)[0].tolist()

if strategy == 'union':
bad = list(set(bad1[0].tolist()).union(set(bad2[0].tolist())))
bad_idx = list(set(bad1_idx).union(set(bad2_idx)))
elif strategy == 'intersection':
bad = list(set(bad1[0].tolist()).intersection(set(bad2[0].tolist())))
bad_idx = list(set(bad1_idx).intersection(set(bad2_idx)))
else:
TypeError('not good strategy input!')
raise RuntimeError('not good strategy input!')

# storing the percentage of epochs rejection
dic_AR['S1'] = float((len(bad1[0].tolist())/len(cleaned_epochs_ICA[0]))*100)
dic_AR['S2'] = float((len(bad2[0].tolist())/len(cleaned_epochs_ICA[1]))*100)
dic_AR['S1'] = float((len(bad1_idx) / len(cleaned_epochs_ICA[0])) * 100)
dic_AR['S2'] = float((len(bad2_idx) / len(cleaned_epochs_ICA[1])) * 100)

# picking good epochs for the two subj
cleaned_epochs_AR: List[mne.Epochs] = []
for clean_epochs in cleaned_epochs_ICA: # per subj
for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj
# keep a copy of the original data
clean_epochs_ep = copy.deepcopy(clean_epochs)
clean_epochs_ep = clean_epochs_ep.drop(indices=bad)
epochs_subj = copy.deepcopy(clean_epochs_subj)
epochs_subj.drop(indices=bad_idx)
# interpolating bads or removing epochs
ar = AR[cleaned_epochs_ICA.index(clean_epochs)]
clean_epochs_AR = ar.transform(clean_epochs_ep)
cleaned_epochs_AR.append(clean_epochs_AR)
ar = AR[subject_id]
epochs_AR = ar.transform(epochs_subj)
cleaned_epochs_AR.append(epochs_AR)

if strategy == 'intersection':
# equalizing epochs length between two participants
mne.epochs.equalize_epoch_counts(cleaned_epochs_AR)

dic_AR['dyad'] = float(((len(cleaned_epochs_ICA[0])-len(cleaned_epochs_AR[0]))/len(cleaned_epochs_ICA[0]))*100)
n_epochs_ICA = len(cleaned_epochs_ICA[0])
n_epochs_AR = len(cleaned_epochs_AR[0])

dic_AR['dyad'] = float(((n_epochs_ICA - n_epochs_AR) / n_epochs_ICA) * 100)
if dic_AR['dyad'] >= threshold:
TypeError('percentage of rejected epochs above threshold!')
raise RuntimeError(f"percentage of rejected epochs ({dic_AR['dyad']}) above threshold ({threshold})! ")
if verbose:
print('%s percent of bad epochs' % dic_AR['dyad'])
print(f"{dic_AR['dyad']} percent of bad epochs")

# Vizualisation before after AR
evoked_before: List[mne.Evoked] = []
for clean_epochs in cleaned_epochs_ICA: # per subj
evoked_before.append(clean_epochs.average())

evoked_after_AR: List[mne.Evoked] = []
for clean in cleaned_epochs_AR:
evoked_after_AR.append(clean.average())
evoked_before: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_ICA]
evoked_after_AR: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_AR]

if verbose:
for i, j in zip(evoked_before, evoked_after_AR):
for evoked_before_subj, evoked_after_AR_subj in zip(evoked_before, evoked_after_AR):
fig, axes = plt.subplots(2, 1, figsize=(6, 6))
for ax in axes:
ax.tick_params(axis='x', which='both', bottom='off', top='off')
ax.tick_params(axis='y', which='both', left='off', right='off')

ylim = dict(grad=(-170, 200))
i.pick_types(eeg=True, exclude=[])
i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)

evoked_before_subj.pick_types(eeg=True, exclude=[])
evoked_before_subj.plot(exclude=[], axes=axes[0], ylim=ylim, show=False)
axes[0].set_title('Before autoreject')
j.pick_types(eeg=True, exclude=[])
j.plot(exclude=[], axes=axes[1], ylim=ylim)

evoked_after_AR_subj.pick_types(eeg=True, exclude=[])
evoked_after_AR_subj.plot(exclude=[], axes=axes[1], ylim=ylim)
# Problème titre ne s'affiche pas pour le deuxieme axe !!!
axes[1].set_title('After autoreject')

plt.tight_layout()

return cleaned_epochs_AR, dic_AR
2 changes: 1 addition & 1 deletion hypyp/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,7 +1735,7 @@ def plot_xwt(sig1: mne.Epochs, sig2: mne.Epochs,
plt.imshow(data, aspect='auto', cmap=my_cm, interpolation='lanczos')

else:
ValueError('Analysis must be set as phase, power, or wtc.')
raise ValueError('Analysis must be set as phase, power, or wtc.')

plt.gca().invert_yaxis()
plt.ylabel('Frequencies (Hz)')
Expand Down
Loading

0 comments on commit 52e9cd6

Please sign in to comment.