From b9a4acae708ee58a0a6256e15e03ec8ed59c581c Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 3 Jan 2024 16:04:41 -0800 Subject: [PATCH 1/8] [ENH, MRG] Add interpolation per epoch [skip ci] --- .../preprocessing/interpolate_bad_channels.py | 21 ++++++++ mne/channels/channels.py | 26 ++-------- mne/epochs.py | 51 +++++++++++++++++++ mne/tests/test_epochs.py | 34 +++++++++++++ mne/utils/docs.py | 31 +++++++++++ 5 files changed, 141 insertions(+), 22 deletions(-) diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index a56aec7d8f7..75a53bc8fab 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -23,6 +23,7 @@ # sphinx_gallery_thumbnail_number = 2 +import numpy as np import mne from mne.datasets import sample @@ -48,6 +49,26 @@ ) evoked_interp_mne.plot(exclude=[], picks=("grad", "eeg")) +# %% +# You can also interpolate bad channels per epoch +raw = mne.io.read_raw(meg_path / "sample_audvis_raw.fif") +events = mne.read_events(meg_path / "sample_audvis_raw-eve.fif") +epochs = mne.Epochs(raw, events=events) + +# try to only remove bad channels on some epochs to save EEG 053 +epochs.drop_bad(reject=dict(grad=4000e-13, mag=4e-12, eeg=100e-6)) + +interpolate_channels = [entry if len(entry) < 5 else tuple() + for entry in epochs.drop_log] +drop_epochs = np.array([len(entry) >= 5 for entry in epochs.drop_log]) +del epochs + +epochs = mne.Epochs(raw, events=events) +epochs.interpolate_bads_per_epoch(interpolate_channels) +epochs = epochs[~drop_epochs] +epochs.average().plot() + + # %% # References # ---------- diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 325be7350a6..2b7aac3df92 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -823,36 +823,18 @@ def interpolate_bads( ---------- reset_bads : bool If True, remove the bads from info. - mode : str - Either ``'accurate'`` or ``'fast'``, determines the quality of the - Legendre polynomial expansion used for interpolation of channels - using the minimum-norm method. - origin : array-like, shape (3,) | str - Origin of the sphere in the head coordinate frame and in meters. - Can be ``'auto'`` (default), which means a head-digitization-based - origin fit. + %(mode_interp)s + %(origin_interp)s .. versionadded:: 0.17 - method : dict | str | None - Method to use for each channel type. - - - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"`` - - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"`` - - ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"`` - - None is an alias for:: - - method=dict(meg="MNE", eeg="spline", fnirs="nearest") - - If a :class:`str` is provided, the method will be applied to all channel - types supported and available in the instance. The method ``"nan"`` will - replace the channel data with ``np.nan``. + %(method_interp)s .. warning:: Be careful when using ``method="nan"``; the default value ``reset_bads=True`` may not be what you want. .. versionadded:: 0.21 + exclude : list | tuple The channels to exclude from interpolation. If excluded a bad channel will stay in bads. diff --git a/mne/epochs.py b/mne/epochs.py index 34d942536bd..3b45e6f9897 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1410,6 +1410,57 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): self._get_data(out=False, verbose=verbose) return self + @verbose + def interpolate_bads_per_epoch( + self, + interp_chs, + mode="accurate", + origin="auto", + method=None, + verbose=None, + ): + """Interpolate bad channels per epoch. + + Parameters + ---------- + interp_chs : list + A list containing channels to interpolate for each epoch. + %(mode_interp)s + %(origin_interp)s + %(method_interp)s + %(verbose)s + + Returns + ------- + epochs : instance of Epochs + The epochs with bad channels interpolated per epoch. Operates in-place. + + Notes + ----- + .. versionadded:: 1.7 + """ + _check_preload(self, "epochs.interpolate_bads_per_epoch") + if len(interp_chs) != len(self): + raise ValueError( + "The length of ``interp_chs`` must be " + f"the same as the the number of epochs ({len(self)}), " + f"got {len(interp_chs)}" + ) + for i, this_interp_ch in enumerate(interp_chs): + chs_not_found = [ch for ch in this_interp_ch if ch not in self.ch_names] + if chs_not_found: + raise ValueError( + f"Channels {chs_not_found} not found " f"for interp_chs[{i}]" + ) + for epoch_idx, this_interp_chs in enumerate(interp_chs): + if this_interp_chs: + logger.debug(f"Interpolating {this_interp_chs} on epoch {epoch_idx}") + epoch = self[epoch_idx] + epoch.info["bads"] = list(this_interp_chs) + epoch.interpolate_bads(mode=mode, origin=origin, method=method) + self._data[epoch_idx] = epoch._data + return self + def drop_log_stats(self, ignore=("IGNORED",)): """Compute the channel stats based on a drop_log from Epochs. diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index c68fc7ce6bd..362f0cc67e4 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -720,6 +720,40 @@ def test_reject(): assert "is a noop" in log +def test_interpolate_bads_per_epochs(): + raw, events, _ = _get_data() + names = raw.ch_names[::5] + assert "MEG 2443" in names + raw.pick(names).load_data() + assert "eog" in raw + raw.info.normalize_proj() + picks = np.arange(len(raw.ch_names)) + # cull the list just to contain the relevant event + events = events[events[:, 2] == event_id, :] + assert len(events) == 7 + # test interpolate bads per epoch + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + ) + with pytest.raises(ValueError, match="length of ``interp_chs``"): + epochs.interpolate_bads_per_epoch([]) + interp_chs = [tuple() for _ in range(len(epochs))] + interp_chs[0] = ("foo",) + with pytest.raises(ValueError, match="Channels .* not found"): + epochs.interpolate_bads_per_epoch(interp_chs) + + interp_chs[0] = (epochs.ch_names[0],) + data_before = epochs.get_data().copy() + data_after = epochs.interpolate_bads_per_epoch(interp_chs).get_data() + assert not np.array_equal(data_before[0], data_after[0]) + + def test_reject_by_annotations_reject_tmin_reject_tmax(): """Test reject_by_annotations with reject_tmin and reject_tmax defined.""" # 10 seconds of data, event at 2s, bad segment from 1s to 1.5s diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 1fa26fa16dd..c667b20c13c 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2221,6 +2221,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): forward-backward filtering (via :func:`~scipy.signal.filtfilt`). """ +docdict["method_interp"] = """ +method : dict | str | None + Method to use for each channel type. + + - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"`` + - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"`` + - ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"`` + + None is an alias for:: + + method=dict(meg="MNE", eeg="spline", fnirs="nearest") + + If a :class:`str` is provided, the method will be applied to all channel + types supported and available in the instance. The method ``"nan"`` will + replace the channel data with ``np.nan``. +""" + docdict["method_kw_psd"] = """\ **method_kw Additional keyword arguments passed to the spectral estimation @@ -2259,6 +2276,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Extraction mode, see Notes. """ +docdict["mode_interp"] = """ +mode : str + Either ``'accurate'`` or ``'fast'``, determines the quality of the + Legendre polynomial expansion used for interpolation of channels + using the minimum-norm method. +""" + docdict["mode_pctf"] = """ mode : None | 'mean' | 'max' | 'svd' Compute summary of PSFs/CTFs across all indices specified in 'idx'. @@ -2658,6 +2682,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The default changed from False in 1.4 to True in 1.5. """ +docdict["origin_interp"] = """ +origin : array-like, shape (3,) | str + Origin of the sphere in the head coordinate frame and in meters. + Can be ``'auto'`` (default), which means a head-digitization-based + origin fit. +""" + docdict["origin_maxwell"] = """ origin : array-like, shape (3,) | str Origin of internal and external multipolar moment space in meters. From a8976a4d143ef41d77afde7584b45df30dae3303 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Jan 2024 00:09:04 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/preprocessing/interpolate_bad_channels.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index 75a53bc8fab..fe730907105 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -24,6 +24,7 @@ # sphinx_gallery_thumbnail_number = 2 import numpy as np + import mne from mne.datasets import sample @@ -58,8 +59,9 @@ # try to only remove bad channels on some epochs to save EEG 053 epochs.drop_bad(reject=dict(grad=4000e-13, mag=4e-12, eeg=100e-6)) -interpolate_channels = [entry if len(entry) < 5 else tuple() - for entry in epochs.drop_log] +interpolate_channels = [ + entry if len(entry) < 5 else tuple() for entry in epochs.drop_log +] drop_epochs = np.array([len(entry) >= 5 for entry in epochs.drop_log]) del epochs From 24a25e5eeb74039b07ccfab21ce752143da3f062 Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 3 Jan 2024 16:10:55 -0800 Subject: [PATCH 3/8] add changelog entry --- doc/changes/devel/12334.newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/devel/12334.newfeature.rst diff --git a/doc/changes/devel/12334.newfeature.rst b/doc/changes/devel/12334.newfeature.rst new file mode 100644 index 00000000000..23a330ce6ce --- /dev/null +++ b/doc/changes/devel/12334.newfeature.rst @@ -0,0 +1 @@ +Add :meth:`mne.Epochs.interpolate_bads_per_channel` to interpolate channels only on selected epochs for which they are bad, by `Alex Rockhill`_ \ No newline at end of file From 9cf13b344b24f0eb536c9c2cbca7cb1500131d82 Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 3 Jan 2024 16:11:29 -0800 Subject: [PATCH 4/8] style --- examples/preprocessing/interpolate_bad_channels.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index 75a53bc8fab..f9741c5675c 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -58,8 +58,9 @@ # try to only remove bad channels on some epochs to save EEG 053 epochs.drop_bad(reject=dict(grad=4000e-13, mag=4e-12, eeg=100e-6)) -interpolate_channels = [entry if len(entry) < 5 else tuple() - for entry in epochs.drop_log] +interpolate_channels = [ + entry if len(entry) < 5 else tuple() for entry in epochs.drop_log +] drop_epochs = np.array([len(entry) >= 5 for entry in epochs.drop_log]) del epochs From 37c5acd674583ff714e59e1fcdadff322c0d071c Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 3 Jan 2024 16:16:02 -0800 Subject: [PATCH 5/8] add docstring --- mne/tests/test_epochs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 362f0cc67e4..c0782852919 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -721,6 +721,7 @@ def test_reject(): def test_interpolate_bads_per_epochs(): + """Test interpolating bad channels per epoch.""" raw, events, _ = _get_data() names = raw.ch_names[::5] assert "MEG 2443" in names From df27b947ccc4f4ec55eacdef56be16124d7e860a Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 3 Jan 2024 17:53:52 -0800 Subject: [PATCH 6/8] fix circle, speed up example --- examples/preprocessing/interpolate_bad_channels.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index fe730907105..b189e7ae011 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -53,11 +53,12 @@ # %% # You can also interpolate bad channels per epoch raw = mne.io.read_raw(meg_path / "sample_audvis_raw.fif") +raw.pick("eeg") # just to speed up execution events = mne.read_events(meg_path / "sample_audvis_raw-eve.fif") epochs = mne.Epochs(raw, events=events) # try to only remove bad channels on some epochs to save EEG 053 -epochs.drop_bad(reject=dict(grad=4000e-13, mag=4e-12, eeg=100e-6)) +epochs.drop_bad(reject=dict(eeg=100e-6)) interpolate_channels = [ entry if len(entry) < 5 else tuple() for entry in epochs.drop_log @@ -65,12 +66,11 @@ drop_epochs = np.array([len(entry) >= 5 for entry in epochs.drop_log]) del epochs -epochs = mne.Epochs(raw, events=events) +epochs = mne.Epochs(raw, events=events, preload=True) epochs.interpolate_bads_per_epoch(interpolate_channels) epochs = epochs[~drop_epochs] epochs.average().plot() - # %% # References # ---------- From d9b4b83950e2eab419f1b4c6832db50ca453994a Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Thu, 4 Jan 2024 08:36:21 -0800 Subject: [PATCH 7/8] saved version roundtrip to disk --- doc/changes/devel/12334.newfeature.rst | 2 +- .../preprocessing/interpolate_bad_channels.py | 3 +- mne/_fiff/meas_info.py | 18 +++++++-- mne/_fiff/pick.py | 9 ++++- mne/_fiff/tag.py | 9 ++++- mne/_fiff/write.py | 38 +++++++++++++++++-- mne/channels/channels.py | 3 +- mne/epochs.py | 14 +++++-- mne/utils/docs.py | 5 +++ 9 files changed, 83 insertions(+), 18 deletions(-) diff --git a/doc/changes/devel/12334.newfeature.rst b/doc/changes/devel/12334.newfeature.rst index 23a330ce6ce..107b8fe8c30 100644 --- a/doc/changes/devel/12334.newfeature.rst +++ b/doc/changes/devel/12334.newfeature.rst @@ -1 +1 @@ -Add :meth:`mne.Epochs.interpolate_bads_per_channel` to interpolate channels only on selected epochs for which they are bad, by `Alex Rockhill`_ \ No newline at end of file +Allow :meth:`mne.Epochs.interpolate_bads` to interpolate channels only on selected epochs for which they are bad, by `Alex Rockhill`_ \ No newline at end of file diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index b189e7ae011..1059408ec55 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -67,7 +67,8 @@ del epochs epochs = mne.Epochs(raw, events=events, preload=True) -epochs.interpolate_bads_per_epoch(interpolate_channels) +epochs.info["bads"] = interpolate_channels +epochs.interpolate_bads() epochs = epochs[~drop_epochs] epochs.average().plot() diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 483ddc34b52..4d31bfa6915 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -951,10 +951,15 @@ def _check_bads_info_compat(bads, info): if not len(bads): return # e.g. in empty_info for bi, bad in enumerate(bads): - _validate_type(bad, str, f"bads[{bi}]") + _validate_type(bad, (str, list, tuple), f"bads[{bi}]") if "ch_names" not in info: # somewhere in init, or deepcopy, or _empty_info, etc. return - missing = [bad for bad in bads if bad not in info["ch_names"]] + missing = [ + bad + for bads_list in bads + for bad in ([bads_list] if isinstance(bads_list, str) else bads_list) + if bad not in info["ch_names"] + ] if len(missing) > 0: raise ValueError(f"bad channel(s) {missing} marked do not exist in info") @@ -2569,8 +2574,13 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): info["dig"] = _format_dig_points(dig) info["bads"] = bads info._update_redundant() - if clean_bads: - info["bads"] = [b for b in bads if b in info["ch_names"]] + if clean_bads and info["bads"]: + if isinstance(info["bads"][0], str): + info["bads"] = [b for b in bads if b in info["ch_names"]] + else: + info["bads"] = [ + [b for b in bads_list if b in info["ch_names"]] for bads_list in bads + ] info["projs"] = projs info["comps"] = comps info["acq_pars"] = acq_pars diff --git a/mne/_fiff/pick.py b/mne/_fiff/pick.py index 4c5854f36fe..269211603ad 100644 --- a/mne/_fiff/pick.py +++ b/mne/_fiff/pick.py @@ -670,7 +670,14 @@ def pick_info(info, sel=(), copy=True, verbose=None): with info._unlock(): info["chs"] = [info["chs"][k] for k in sel] info._update_redundant() - info["bads"] = [ch for ch in info["bads"] if ch in info["ch_names"]] + if info["bads"]: + if isinstance(info["bads"][0], str): + info["bads"] = [ch for ch in info["bads"] if ch in info["ch_names"]] + else: + info["bads"] = [ + [ch for ch in bads_list if ch in info["ch_names"]] + for bads_list in info["bads"] + ] if "comps" in info: comps = deepcopy(info["comps"]) for c in comps: diff --git a/mne/_fiff/tag.py b/mne/_fiff/tag.py index 81ed12baf6f..6b7848a2bbc 100644 --- a/mne/_fiff/tag.py +++ b/mne/_fiff/tag.py @@ -517,7 +517,14 @@ def has_tag(node, kind): def _rename_list(bads, ch_names_mapping): - return [ch_names_mapping.get(bad, bad) for bad in bads] + if not bads: + return bads + if isinstance(bads[0], str): + return [ch_names_mapping.get(bad, bad) for bad in bads] + else: + return [ + [ch_names_mapping.get(bad, bad) for bad in bads_list] for bads_list in bads + ] def _int_item(x): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 3e6621d0069..a35cbd4c82c 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -156,16 +156,46 @@ def write_name_list_sanitized(fid, kind, lst, name): def _safe_name_list(lst, operation, name): if operation == "write": assert isinstance(lst, (list, tuple, np.ndarray)), type(lst) - if any("{COLON}" in val for val in lst): - raise ValueError(f'The substring "{{COLON}}" in {name} not supported.') - return ":".join(val.replace(":", "{COLON}") for val in lst) + if any( + "{COLON}" in val or "{COMMA}" in val + for val2 in lst + for val in ([val2] if isinstance(val2, str) else val2) + ): + raise ValueError( + f'The substring "{{COLON}}" or "{{COMMA}}" in {name} not supported.' + ) + if not lst: + return "" + if isinstance(lst[0], str): + return ":".join(val.replace(":", "{COLON}") for val in lst) + return ",".join( + ":".join( + val.replace(":", "{COLON}").replace(",", "{COMMA}") for val in val2 + ) + for val2 in lst + ) else: # take a sanitized string and return a list of strings assert operation == "read" assert lst is None or isinstance(lst, str) if not lst: # None or empty string return [] - return [val.replace("{COLON}", ":") for val in lst.split(":")] + if "," in lst: + return [ + ( + [ + val.replace("{COLON}", ":").replace("{COMMA}", ",") + for val in val2.split(":") + ] + if val2 + else list() + ) + for val2 in lst.split(",") + ] + return [ + val.replace("{COLON}", ":").replace("{COMMA}", ",") + for val in lst.split(":") + ] def write_float_matrix(fid, kind, mat): diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 2b7aac3df92..5ed3a8967b3 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -821,8 +821,7 @@ def interpolate_bads( Parameters ---------- - reset_bads : bool - If True, remove the bads from info. + %(reset_bads)s %(mode_interp)s %(origin_interp)s diff --git a/mne/epochs.py b/mne/epochs.py index 3b45e6f9897..f65b3e87e19 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -66,7 +66,12 @@ ) from .baseline import _check_baseline, _log_rescale, rescale from .bem import _check_origin -from .channels.channels import InterpolationMixin, ReferenceMixin, UpdateChannelsMixin +from .channels.channels import ( + InterpolationMixin, + ReferenceMixin, + UpdateChannelsMixin, + interpolate_bads, +) from .event import _read_events_fif, make_fixed_length_events, match_event_names from .evoked import EvokedArray from .filter import FilterMixin, _check_fun, detrend @@ -1410,13 +1415,14 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): self._get_data(out=False, verbose=verbose) return self - @verbose - def interpolate_bads_per_epoch( + @copy_function_doc_to_method_doc(interpolate_bads) + def interpolate_bads( self, - interp_chs, + reset_bads=True, mode="accurate", origin="auto", method=None, + exclude=(), verbose=None, ): """Interpolate bad channels per epoch. diff --git a/mne/utils/docs.py b/mne/utils/docs.py index c667b20c13c..d90acbc3a17 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3357,6 +3357,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The resolution of the topomap image (number of pixels along each side). """ +docdict["reset_bads"] = """ +reset_bads : bool + If True, remove the bads from info. +""" + docdict["return_pca_vars_pctf"] = """ return_pca_vars : bool Whether or not to return the explained variances across the specified From ab01496aa5d45d2dfe0bfc925d68271ed3c30e31 Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Thu, 4 Jan 2024 09:07:47 -0800 Subject: [PATCH 8/8] revert info bads changes, overload --- doc/changes/devel/12334.newfeature.rst | 2 +- .../preprocessing/interpolate_bad_channels.py | 3 +- mne/_fiff/meas_info.py | 18 ++------ mne/_fiff/pick.py | 9 +--- mne/_fiff/tag.py | 9 +--- mne/_fiff/write.py | 38 ++-------------- mne/channels/channels.py | 5 +-- mne/epochs.py | 44 ++++++++++++++----- mne/tests/test_epochs.py | 8 ++-- mne/utils/docs.py | 6 +++ 10 files changed, 56 insertions(+), 86 deletions(-) diff --git a/doc/changes/devel/12334.newfeature.rst b/doc/changes/devel/12334.newfeature.rst index 107b8fe8c30..da738a25838 100644 --- a/doc/changes/devel/12334.newfeature.rst +++ b/doc/changes/devel/12334.newfeature.rst @@ -1 +1 @@ -Allow :meth:`mne.Epochs.interpolate_bads` to interpolate channels only on selected epochs for which they are bad, by `Alex Rockhill`_ \ No newline at end of file +Allow :meth:`mne.Epochs.interpolate_bads` to interpolate channels only on selected epochs for which they are bad using ``interp_chs``, by `Alex Rockhill`_ \ No newline at end of file diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index 1059408ec55..969b090d03d 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -67,8 +67,7 @@ del epochs epochs = mne.Epochs(raw, events=events, preload=True) -epochs.info["bads"] = interpolate_channels -epochs.interpolate_bads() +epochs.interpolate_bads(interp_chs=interpolate_channels) epochs = epochs[~drop_epochs] epochs.average().plot() diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 4d31bfa6915..483ddc34b52 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -951,15 +951,10 @@ def _check_bads_info_compat(bads, info): if not len(bads): return # e.g. in empty_info for bi, bad in enumerate(bads): - _validate_type(bad, (str, list, tuple), f"bads[{bi}]") + _validate_type(bad, str, f"bads[{bi}]") if "ch_names" not in info: # somewhere in init, or deepcopy, or _empty_info, etc. return - missing = [ - bad - for bads_list in bads - for bad in ([bads_list] if isinstance(bads_list, str) else bads_list) - if bad not in info["ch_names"] - ] + missing = [bad for bad in bads if bad not in info["ch_names"]] if len(missing) > 0: raise ValueError(f"bad channel(s) {missing} marked do not exist in info") @@ -2574,13 +2569,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): info["dig"] = _format_dig_points(dig) info["bads"] = bads info._update_redundant() - if clean_bads and info["bads"]: - if isinstance(info["bads"][0], str): - info["bads"] = [b for b in bads if b in info["ch_names"]] - else: - info["bads"] = [ - [b for b in bads_list if b in info["ch_names"]] for bads_list in bads - ] + if clean_bads: + info["bads"] = [b for b in bads if b in info["ch_names"]] info["projs"] = projs info["comps"] = comps info["acq_pars"] = acq_pars diff --git a/mne/_fiff/pick.py b/mne/_fiff/pick.py index 269211603ad..4c5854f36fe 100644 --- a/mne/_fiff/pick.py +++ b/mne/_fiff/pick.py @@ -670,14 +670,7 @@ def pick_info(info, sel=(), copy=True, verbose=None): with info._unlock(): info["chs"] = [info["chs"][k] for k in sel] info._update_redundant() - if info["bads"]: - if isinstance(info["bads"][0], str): - info["bads"] = [ch for ch in info["bads"] if ch in info["ch_names"]] - else: - info["bads"] = [ - [ch for ch in bads_list if ch in info["ch_names"]] - for bads_list in info["bads"] - ] + info["bads"] = [ch for ch in info["bads"] if ch in info["ch_names"]] if "comps" in info: comps = deepcopy(info["comps"]) for c in comps: diff --git a/mne/_fiff/tag.py b/mne/_fiff/tag.py index 6b7848a2bbc..81ed12baf6f 100644 --- a/mne/_fiff/tag.py +++ b/mne/_fiff/tag.py @@ -517,14 +517,7 @@ def has_tag(node, kind): def _rename_list(bads, ch_names_mapping): - if not bads: - return bads - if isinstance(bads[0], str): - return [ch_names_mapping.get(bad, bad) for bad in bads] - else: - return [ - [ch_names_mapping.get(bad, bad) for bad in bads_list] for bads_list in bads - ] + return [ch_names_mapping.get(bad, bad) for bad in bads] def _int_item(x): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index a35cbd4c82c..3e6621d0069 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -156,46 +156,16 @@ def write_name_list_sanitized(fid, kind, lst, name): def _safe_name_list(lst, operation, name): if operation == "write": assert isinstance(lst, (list, tuple, np.ndarray)), type(lst) - if any( - "{COLON}" in val or "{COMMA}" in val - for val2 in lst - for val in ([val2] if isinstance(val2, str) else val2) - ): - raise ValueError( - f'The substring "{{COLON}}" or "{{COMMA}}" in {name} not supported.' - ) - if not lst: - return "" - if isinstance(lst[0], str): - return ":".join(val.replace(":", "{COLON}") for val in lst) - return ",".join( - ":".join( - val.replace(":", "{COLON}").replace(",", "{COMMA}") for val in val2 - ) - for val2 in lst - ) + if any("{COLON}" in val for val in lst): + raise ValueError(f'The substring "{{COLON}}" in {name} not supported.') + return ":".join(val.replace(":", "{COLON}") for val in lst) else: # take a sanitized string and return a list of strings assert operation == "read" assert lst is None or isinstance(lst, str) if not lst: # None or empty string return [] - if "," in lst: - return [ - ( - [ - val.replace("{COLON}", ":").replace("{COMMA}", ",") - for val in val2.split(":") - ] - if val2 - else list() - ) - for val2 in lst.split(",") - ] - return [ - val.replace("{COLON}", ":").replace("{COMMA}", ",") - for val in lst.split(":") - ] + return [val.replace("{COLON}", ":") for val in lst.split(":")] def write_float_matrix(fid, kind, mat): diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 5ed3a8967b3..25eca7ce65a 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -813,6 +813,7 @@ def interpolate_bads( origin="auto", method=None, exclude=(), + *, verbose=None, ): """Interpolate bad MEG and EEG channels. @@ -834,9 +835,7 @@ def interpolate_bads( .. versionadded:: 0.21 - exclude : list | tuple - The channels to exclude from interpolation. If excluded a bad - channel will stay in bads. + %(exclude_interp)s %(verbose)s Returns diff --git a/mne/epochs.py b/mne/epochs.py index f65b3e87e19..bd2b0c31dab 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -66,12 +66,7 @@ ) from .baseline import _check_baseline, _log_rescale, rescale from .bem import _check_origin -from .channels.channels import ( - InterpolationMixin, - ReferenceMixin, - UpdateChannelsMixin, - interpolate_bads, -) +from .channels.channels import InterpolationMixin, ReferenceMixin, UpdateChannelsMixin from .event import _read_events_fif, make_fixed_length_events, match_event_names from .evoked import EvokedArray from .filter import FilterMixin, _check_fun, detrend @@ -1415,7 +1410,6 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): self._get_data(out=False, verbose=verbose) return self - @copy_function_doc_to_method_doc(interpolate_bads) def interpolate_bads( self, reset_bads=True, @@ -1423,17 +1417,25 @@ def interpolate_bads( origin="auto", method=None, exclude=(), + *, + interp_chs=None, verbose=None, ): - """Interpolate bad channels per epoch. + """Interpolate bad channels, optionally different channels on each epoch. + + Operates in-place. Parameters ---------- - interp_chs : list - A list containing channels to interpolate for each epoch. + %(reset_bads)s %(mode_interp)s %(origin_interp)s %(method_interp)s + %(exclude_interp)s + interp_chs : list | None + The channels to interpolate on each epoch. Must be the same + length as the epochs and each list entry must be a list of channels + to interpolate for each epoch. %(verbose)s Returns @@ -1443,9 +1445,19 @@ def interpolate_bads( Notes ----- + The ``"MNE"`` method uses minimum-norm projection to a sphere and back. + .. versionadded:: 1.7 """ - _check_preload(self, "epochs.interpolate_bads_per_epoch") + if interp_chs is None: + return super().interpolate_bads( + reset_bads=reset_bads, + mode=mode, + origin=origin, + method=method, + exclude=exclude, + ) + _check_preload(self, "epochs.interpolate_bads") if len(interp_chs) != len(self): raise ValueError( "The length of ``interp_chs`` must be " @@ -1463,8 +1475,16 @@ def interpolate_bads( logger.debug(f"Interpolating {this_interp_chs} on epoch {epoch_idx}") epoch = self[epoch_idx] epoch.info["bads"] = list(this_interp_chs) - epoch.interpolate_bads(mode=mode, origin=origin, method=method) + epoch.interpolate_bads( + reset_bads=epoch_idx == 0, # check once for warning + mode=mode, + origin=origin, + method=method, + exclude=exclude, + ) self._data[epoch_idx] = epoch._data + if epoch_idx == 0: + self.info["bads"] = epoch.info["bads"] return self def drop_log_stats(self, ignore=("IGNORED",)): diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index c0782852919..f339c8fe9ca 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -720,7 +720,7 @@ def test_reject(): assert "is a noop" in log -def test_interpolate_bads_per_epochs(): +def test_interpolate_bads_per_epoch(): """Test interpolating bad channels per epoch.""" raw, events, _ = _get_data() names = raw.ch_names[::5] @@ -743,15 +743,15 @@ def test_interpolate_bads_per_epochs(): preload=True, ) with pytest.raises(ValueError, match="length of ``interp_chs``"): - epochs.interpolate_bads_per_epoch([]) + epochs.interpolate_bads(interp_chs=[]) interp_chs = [tuple() for _ in range(len(epochs))] interp_chs[0] = ("foo",) with pytest.raises(ValueError, match="Channels .* not found"): - epochs.interpolate_bads_per_epoch(interp_chs) + epochs.interpolate_bads(interp_chs=interp_chs) interp_chs[0] = (epochs.ch_names[0],) data_before = epochs.get_data().copy() - data_after = epochs.interpolate_bads_per_epoch(interp_chs).get_data() + data_after = epochs.interpolate_bads(interp_chs=interp_chs).get_data() assert not np.array_equal(data_before[0], data_after[0]) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index d90acbc3a17..27d1b1d1c58 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1168,6 +1168,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (below the nasion) and positive Y values (in front of the LPA/RPA). """ +docdict["exclude_interp"] = """ +exclude : list | tuple + The channels to exclude from interpolation. If excluded a bad + channel will stay in bads. +""" + _exclude_spectrum = """\ exclude : list of str | 'bads' Channel names to exclude{}. If ``'bads'``, channels