diff --git a/hypyp/analyses.py b/hypyp/analyses.py index cc6afb4..844fcd5 100644 --- a/hypyp/analyses.py +++ b/hypyp/analyses.py @@ -347,7 +347,7 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int, frequen elif type(frequencies) == dict: values = compute_freq_bands(data, sampling_rate, frequencies) else: - TypeError("Please use a list or a dictionary to specify frequencies.") + raise TypeError("Please use a list or a dictionary to specify frequencies.") # compute connectivity values result = compute_sync(values, mode, epochs_average) @@ -511,7 +511,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T con = con_num / con_den else: - ValueError('Metric type not supported.') + raise ValueError('Metric type not supported.') con = con.swapaxes(0, 1) # n_freq x n_epoch x 2*n_ch x 2*n_ch if epochs_average: diff --git a/hypyp/prep.py b/hypyp/prep.py index abb65c1..ac3190b 100644 --- a/hypyp/prep.py +++ b/hypyp/prep.py @@ -18,7 +18,7 @@ import mne from autoreject import get_rejection_threshold, AutoReject, RejectLog from mne.preprocessing import ICA, corrmap -from typing import List, Tuple, TypedDict +from typing import List, Tuple, TypedDict, Union class DicAR(TypedDict): """ @@ -31,7 +31,10 @@ class DicAR(TypedDict): dyad: float -def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]: +def filt( + raw_S: List[mne.io.Raw], + freqs: Tuple[Union[float, None], Union[float, None]] = (2., None) +) -> List[mne.io.Raw]: """ Filters list of raw data to remove slow drifts. @@ -42,11 +45,8 @@ def filt(raw_S: List[mne.io.Raw]) -> List[mne.io.Raw]: Returns: raws: list of high-pass filtered raws. """ - # TODO: l_freq and h_freq as param - raws: List[mne.io.Raw] = [] - for raw in raw_S: - raws.append(mne.io.Raw.filter(raw, l_freq=2., h_freq=None)) - + + raws = [mne.io.Raw.filter(raw, l_freq=freqs[0], h_freq=freqs[1]) for raw in raw_S] return raws @@ -77,68 +77,71 @@ def ICA_choice_comp(icas: List[ICA], epochs: List[mne.Epochs]) -> List[mne.Epoch # choosing participant and its component as a template for the other participant # if do not want to apply ICA on the data, do not fill the answer - subj_numb = input("Which participant ICA do you want" - " to use as a template for artifact rejection?" - " Index begins at zero. (If you do not want to apply" - " ICA on your data, do not enter nothing and press enter.)") - comp_number = input("Which IC do you want to use as a template?" + subject_id = input("Which participant ICA do you want" + " to use as a template for artifact rejection?" + " Index begins at zero. (If you do not want to apply" + " ICA on your data, do not enter nothing and press enter.)") + + component_id = input("Which IC do you want to use as a template?" " Index begins at zero. (If you did not choose" - "a participant number at first question," - "then do not enter nothing and press enter again" - "to not apply ICA on your data)") - - # applying ICA - if (len(subj_numb) != 0 and len(comp_number) != 0): - cleaned_epochs_ICA = ICA_apply(icas, - int(subj_numb), - int(comp_number), - epochs) - else: - cleaned_epochs_ICA = epochs + " a participant number at first question," + " then do not enter nothing and press enter again" + " to not apply ICA on your data)") - return cleaned_epochs_ICA + if (len(subject_id) == 0 or len(component_id) == 0): + return epochs + + return ICA_apply(icas, int(subject_id), int(component_id), epochs) -def ICA_apply(icas: List[ICA], subj_number: int, comp_number: int, epochs: List[mne.Epochs]) -> List[mne.Epochs]: +def ICA_apply(icas: List[ICA], subject_id: int, component_id: int, epochs: List[mne.Epochs], plot: bool = True) -> List[mne.Epochs]: """ Applies ICA with template model from 1 participant in the dyad. See ICA_choice_comp for a detailed description of the parameters and output. """ cleaned_epochs_ICA: List[ICA] = [] + # selecting which ICs corresponding to the template - template_eog_component = icas[subj_number].get_components()[:, comp_number] + template_eog_component = icas[subject_id].get_components()[:, component_id] # applying corrmap with at least 1 component detected for each subj - fig_template, fig_detected = corrmap(icas, - template=template_eog_component, - threshold=0.9, - label='blink', - ch_type='eeg') + corrmap(icas, + template=template_eog_component, + threshold=0.9, + label='blink', + ch_type='eeg', + plot=plot, + ) # labeling the ICs that capture blink artifacts print([ica.labels_ for ica in icas]) # selecting ICA components after viz - for i in icas: - i.exclude = i.labels_['blink'] + for ica in icas: + ica.exclude = ica.labels_['blink'] - epoch_all_ch = [] # applying ica on clean_epochs # for each participant - for i, j in zip(range(0, len(epochs)), icas): + for subject_id, ica in zip(range(0, len(epochs)), icas): # taking all channels to apply ICA - bads = epochs[i].info['bads'] - epoch_all_ch.append(mne.Epochs.copy(epochs[i])) - epoch_all_ch[i].info['bads'] = [] - j.apply(epoch_all_ch[i]) - epoch_all_ch[i].info['bads'] = bads - cleaned_epochs_ICA.append(epoch_all_ch[i]) + epochs_subj = mne.Epochs.copy(epochs[subject_id]) + bads_keep = epochs_subj.info['bads'] + epochs_subj.info['bads'] = [] + ica.apply(epochs_subj) + epochs_subj.info['bads'] = bads_keep + cleaned_epochs_ICA.append(epochs_subj) return cleaned_epochs_ICA -def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params: dict, random_state: int) -> List[ICA]: +def ICA_fit( + epochs: List[mne.Epochs], + n_components: int, + method: str, + fit_params: dict, + random_state: int +) -> List[ICA]: """ Computes global Autorejection to fit Independent Components Analysis on Epochs, for each participant. @@ -178,23 +181,24 @@ def ICA_fit(epochs: List[mne.Epochs], n_components: int, method: str, fit_params objects, see MNE documentation for more details). """ icas: List[ICA] = [] - for epoch in epochs: + for epochs_subj in epochs: # per subj # applying AR to find global rejection threshold - reject = get_rejection_threshold(epoch, ch_types='eeg') + reject = get_rejection_threshold(epochs_subj, ch_types='eeg') # if very long, can change decim value - print('The rejection dictionary is %s' % reject) + print(f"The rejection dictionary is {reject}") # fitting ICA on filt_raw after AR ica = ICA(n_components=n_components, method=method, fit_params= fit_params, - random_state=random_state).fit(epoch) + random_state=random_state) + # take bad channels into account in ICA fit - epoch_all_ch = mne.Epochs.copy(epoch) - epoch_all_ch.info['bads'] = [] - epoch_all_ch.drop_bad(reject=reject, flat=None) - icas.append(ica.fit(epoch_all_ch)) + epochs_fit = mne.Epochs.copy(epochs_subj) + epochs_fit.info['bads'] = [] + epochs_fit.drop_bad(reject=reject, flat=None) + icas.append(ica.fit(epochs_fit)) return icas @@ -227,7 +231,7 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre for each subject and for the intersection of the them. """ - bad_epochs_AR: List[RejectLog] = [] + reject_logs: List[RejectLog] = [] AR: List[AutoReject] = [] dic_AR: DicAR = {} dic_AR['strategy'] = strategy @@ -241,89 +245,91 @@ def AR_local(cleaned_epochs_ICA: List[mne.Epochs], strategy: str = 'union', thre # n_interpolates = np.array([1, 4, 8, 16, 32, 64]) # consensus_percs = np.linspace(0.5, 1.0, 11) - for clean_epochs in cleaned_epochs_ICA: # per subj - + for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj picks = mne.pick_types( - clean_epochs[0].info, + clean_epochs_subj[subject_id].info, meg=False, eeg=True, stim=False, eog=False, exclude=[]) - ar = AutoReject(n_interpolates, consensus_percs, picks=picks, - thresh_method='random_search', random_state=42, - verbose='tqdm_notebook') - AR.append(ar) + ar = AutoReject(n_interpolates, + consensus_percs, + picks=picks, + thresh_method='random_search', + random_state=42, + verbose='tqdm_notebook') # fitting AR to get bad epochs - ar.fit(clean_epochs) - reject_log = ar.get_reject_log(clean_epochs, picks=picks) - bad_epochs_AR.append(reject_log) + ar.fit(clean_epochs_subj) + reject_log = ar.get_reject_log(clean_epochs_subj, picks=picks) - # taking bad epochs for min 1 subj (dyad) - log1 = bad_epochs_AR[0] - log2 = bad_epochs_AR[1] + AR.append(ar) + reject_logs.append(reject_log) - bad1 = np.where(log1.bad_epochs == True) - bad2 = np.where(log2.bad_epochs == True) + # taking bad epochs for min 1 subj (dyad) + bad1_idx = np.where(reject_logs[0].bad_epochs == True)[0].tolist() + bad2_idx = np.where(reject_logs[1].bad_epochs == True)[0].tolist() if strategy == 'union': - bad = list(set(bad1[0].tolist()).union(set(bad2[0].tolist()))) + bad_idx = list(set(bad1_idx).union(set(bad2_idx))) elif strategy == 'intersection': - bad = list(set(bad1[0].tolist()).intersection(set(bad2[0].tolist()))) + bad_idx = list(set(bad1_idx).intersection(set(bad2_idx))) else: - TypeError('not good strategy input!') + raise RuntimeError('not good strategy input!') # storing the percentage of epochs rejection - dic_AR['S1'] = float((len(bad1[0].tolist())/len(cleaned_epochs_ICA[0]))*100) - dic_AR['S2'] = float((len(bad2[0].tolist())/len(cleaned_epochs_ICA[1]))*100) + dic_AR['S1'] = float((len(bad1_idx) / len(cleaned_epochs_ICA[0])) * 100) + dic_AR['S2'] = float((len(bad2_idx) / len(cleaned_epochs_ICA[1])) * 100) # picking good epochs for the two subj cleaned_epochs_AR: List[mne.Epochs] = [] - for clean_epochs in cleaned_epochs_ICA: # per subj + + for subject_id, clean_epochs_subj in enumerate(cleaned_epochs_ICA): # per subj # keep a copy of the original data - clean_epochs_ep = copy.deepcopy(clean_epochs) - clean_epochs_ep = clean_epochs_ep.drop(indices=bad) + epochs_subj = copy.deepcopy(clean_epochs_subj) + epochs_subj.drop(indices=bad_idx) # interpolating bads or removing epochs - ar = AR[cleaned_epochs_ICA.index(clean_epochs)] - clean_epochs_AR = ar.transform(clean_epochs_ep) - cleaned_epochs_AR.append(clean_epochs_AR) + ar = AR[subject_id] + epochs_AR = ar.transform(epochs_subj) + cleaned_epochs_AR.append(epochs_AR) if strategy == 'intersection': # equalizing epochs length between two participants mne.epochs.equalize_epoch_counts(cleaned_epochs_AR) - dic_AR['dyad'] = float(((len(cleaned_epochs_ICA[0])-len(cleaned_epochs_AR[0]))/len(cleaned_epochs_ICA[0]))*100) + n_epochs_ICA = len(cleaned_epochs_ICA[0]) + n_epochs_AR = len(cleaned_epochs_AR[0]) + + dic_AR['dyad'] = float(((n_epochs_ICA - n_epochs_AR) / n_epochs_ICA) * 100) if dic_AR['dyad'] >= threshold: - TypeError('percentage of rejected epochs above threshold!') + raise RuntimeError(f"percentage of rejected epochs ({dic_AR['dyad']}) above threshold ({threshold})! ") if verbose: - print('%s percent of bad epochs' % dic_AR['dyad']) + print(f"{dic_AR['dyad']} percent of bad epochs") # Vizualisation before after AR - evoked_before: List[mne.Evoked] = [] - for clean_epochs in cleaned_epochs_ICA: # per subj - evoked_before.append(clean_epochs.average()) - - evoked_after_AR: List[mne.Evoked] = [] - for clean in cleaned_epochs_AR: - evoked_after_AR.append(clean.average()) + evoked_before: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_ICA] + evoked_after_AR: List[mne.Evoked] = [epochs.average() for epochs in cleaned_epochs_AR] if verbose: - for i, j in zip(evoked_before, evoked_after_AR): + for evoked_before_subj, evoked_after_AR_subj in zip(evoked_before, evoked_after_AR): fig, axes = plt.subplots(2, 1, figsize=(6, 6)) for ax in axes: ax.tick_params(axis='x', which='both', bottom='off', top='off') ax.tick_params(axis='y', which='both', left='off', right='off') ylim = dict(grad=(-170, 200)) - i.pick_types(eeg=True, exclude=[]) - i.plot(exclude=[], axes=axes[0], ylim=ylim, show=False) + + evoked_before_subj.pick_types(eeg=True, exclude=[]) + evoked_before_subj.plot(exclude=[], axes=axes[0], ylim=ylim, show=False) axes[0].set_title('Before autoreject') - j.pick_types(eeg=True, exclude=[]) - j.plot(exclude=[], axes=axes[1], ylim=ylim) + + evoked_after_AR_subj.pick_types(eeg=True, exclude=[]) + evoked_after_AR_subj.plot(exclude=[], axes=axes[1], ylim=ylim) # Problème titre ne s'affiche pas pour le deuxieme axe !!! axes[1].set_title('After autoreject') + plt.tight_layout() return cleaned_epochs_AR, dic_AR diff --git a/hypyp/viz.py b/hypyp/viz.py index 083b513..294f4fa 100644 --- a/hypyp/viz.py +++ b/hypyp/viz.py @@ -1735,7 +1735,7 @@ def plot_xwt(sig1: mne.Epochs, sig2: mne.Epochs, plt.imshow(data, aspect='auto', cmap=my_cm, interpolation='lanczos') else: - ValueError('Analysis must be set as phase, power, or wtc.') + raise ValueError('Analysis must be set as phase, power, or wtc.') plt.gca().invert_yaxis() plt.ylabel('Frequencies (Hz)') diff --git a/tests/test_prep.py b/tests/test_prep.py new file mode 100644 index 0000000..b8a785e --- /dev/null +++ b/tests/test_prep.py @@ -0,0 +1,90 @@ +import pytest + +import mne +import numpy as np + +from hypyp import prep + +def test_filt(): + n_channels = 10 + n_samples = 100000 + sfreq = 1000 + + # create noise signal + info = mne.create_info(ch_names=n_channels, sfreq=sfreq, ch_types='eeg') + data = np.random.normal(size=(n_channels, n_samples)) + raw = mne.io.RawArray(data, info) + raw_psd = raw.compute_psd().get_data() + + # compare on lowest freq + raw_filt_default, = prep.filt([raw]) + raw_filt_default_psd = raw_filt_default.compute_psd().get_data() + assert np.sum(raw_filt_default_psd[:,0]) < np.sum(raw_psd[:,0]) + + # compare on highest freq + raw_filt, = prep.filt([raw], (2., 10)) + raw_filt_psd = raw_filt.compute_psd().get_data() + assert np.sum(raw_filt_psd[:,-1]) < np.sum(raw_filt_default_psd[:,-1]) + + +@pytest.mark.parametrize("fit_kwargs", [ + dict(method='fastica', fit_params=dict(tol=0.01)), # increase tolerance to converge + dict(method='infomax', fit_params=dict(extended=True)) +]) +def test_ICA(epochs, fit_kwargs): + ep = [epochs.epo1, epochs.epo2] + icas = prep.ICA_fit(ep, n_components=15, **fit_kwargs, random_state=97) + + assert len(icas) == len(ep) + + # check that the number of componenents is similar between the two participants + for i in range(0, len(icas)-1): + assert mne.preprocessing.ICA.get_components( + icas[i]).shape == mne.preprocessing.ICA.get_components(icas[i+1]).shape + + cleaned_epochs_ICA = prep.ICA_apply(icas, 0, 0, ep, plot=False) + + # check bad channels are not deleted + assert epochs.epo1.info['ch_names'] == cleaned_epochs_ICA[0].info['ch_names'] + assert epochs.epo2.info['ch_names'] == cleaned_epochs_ICA[1].info['ch_names'] + + # check signal change by comparing total amplitude + raw_amplitudes = np.mean(np.abs(epochs.epo1.get_data(copy=True)), axis=1) + processed_amplitudes = np.mean(np.abs(cleaned_epochs_ICA[0].get_data(copy=True)), axis=1) + assert np.sum(processed_amplitudes) < np.sum(raw_amplitudes) + + +@pytest.mark.parametrize("AR_local_kwargs", [ + dict(strategy='union'), + dict(strategy='intersection'), +]) +def test_AR_local(epochs, AR_local_kwargs): + # test on epochs, but usually applied on cleaned epochs with ICA + cleaned_epochs_AR, dic_AR = prep.AR_local( + [epochs.epo1, epochs.epo2], + **AR_local_kwargs, + threshold=50.0, + verbose=False + ) + assert len(epochs.epo1) >= len(cleaned_epochs_AR[0]) + assert len(epochs.epo2) >= len(cleaned_epochs_AR[1]) + assert len(cleaned_epochs_AR[0]) == len(cleaned_epochs_AR[1]) + assert dic_AR['S2'] + dic_AR['S1'] == dic_AR['dyad'] + assert dic_AR['S2'] <= dic_AR['dyad'] + + +@pytest.mark.parametrize("AR_local_kwargs", [ + dict(strategy='union'), + dict(strategy='intersection'), +]) +def test_AR_local_exception(epochs, AR_local_kwargs): + # test the threshold + with pytest.raises(Exception): + prep.AR_local( + [epochs.epo1, epochs.epo2], + **AR_local_kwargs, + threshold=0.0, + verbose=False + ) + + diff --git a/tests/test_stats.py b/tests/test_stats.py index 29b1959..f1d7cb5 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -6,7 +6,6 @@ import numpy as np import scipy import mne -from hypyp import prep from hypyp import stats from hypyp import utils from hypyp import analyses @@ -60,44 +59,6 @@ def test_metaconn_matrix_2brains(epochs): assert metaconn_freq[n+tot*i, p+tot*i] == ch_con_freq[n, p-tot] -def test_ICA(epochs): - """ - Test ICA fit, ICA choice comp and ICA apply - """ - ep = [epochs.epo1, epochs.epo2] - icas = prep.ICA_fit(ep, n_components=15, method='infomax', fit_params=dict(extended=True), - random_state=97) - # check that the number of componenents is similar between the two participants - for i in range(0, len(icas)-1): - mne.preprocessing.ICA.get_components( - icas[i]).shape == mne.preprocessing.ICA.get_components(icas[i+1]).shape - # cleaned_epochs_ICA = prep.ICA_choice_comp(icas, ep) # pb interactive window - # check signal better after than before - # check bad channels are not deleted - # assert epochs.epo1.info['ch_names'] == cleaned_epochs_ICA[0].info['ch_names'] - # assert epochs.epo2.info['ch_names'] == cleaned_epochs_ICA[1].info['ch_names'] - - -def test_AR_local(epochs): - """ - Test AR local - """ - # test on epochs, but usually applied on cleaned epochs with ICA - ep = [epochs.epo1, epochs.epo2] - cleaned_epochs_AR, dic_AR = prep.AR_local( - ep, strategy='union', threshold=50.0, verbose=False) - assert len(epochs.epo1) >= len(cleaned_epochs_AR[0]) - assert len(epochs.epo2) >= len(cleaned_epochs_AR[1]) - assert len(cleaned_epochs_AR[0]) == len(cleaned_epochs_AR[1]) - assert dic_AR['S2'] + dic_AR['S1'] == dic_AR['dyad'] - cleaned_epochs_AR, dic_AR = prep.AR_local( - ep, strategy='intersection', threshold=50.0, verbose=False) - assert dic_AR['S2'] <= dic_AR['dyad'] - cleaned_epochs_AR, dic_AR = prep.AR_local( - ep, strategy='intersection', threshold=0.0, verbose=False) - # should print an error - - def test_PSD(epochs): """ Test PSD