diff --git a/doc/changes/devel/12334.newfeature.rst b/doc/changes/devel/12334.newfeature.rst new file mode 100644 index 00000000000..da738a25838 --- /dev/null +++ b/doc/changes/devel/12334.newfeature.rst @@ -0,0 +1 @@ +Allow :meth:`mne.Epochs.interpolate_bads` to interpolate channels only on selected epochs for which they are bad using ``interp_chs``, by `Alex Rockhill`_ \ No newline at end of file diff --git a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index a56aec7d8f7..969b090d03d 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -23,6 +23,8 @@ # sphinx_gallery_thumbnail_number = 2 +import numpy as np + import mne from mne.datasets import sample @@ -48,6 +50,27 @@ ) evoked_interp_mne.plot(exclude=[], picks=("grad", "eeg")) +# %% +# You can also interpolate bad channels per epoch +raw = mne.io.read_raw(meg_path / "sample_audvis_raw.fif") +raw.pick("eeg") # just to speed up execution +events = mne.read_events(meg_path / "sample_audvis_raw-eve.fif") +epochs = mne.Epochs(raw, events=events) + +# try to only remove bad channels on some epochs to save EEG 053 +epochs.drop_bad(reject=dict(eeg=100e-6)) + +interpolate_channels = [ + entry if len(entry) < 5 else tuple() for entry in epochs.drop_log +] +drop_epochs = np.array([len(entry) >= 5 for entry in epochs.drop_log]) +del epochs + +epochs = mne.Epochs(raw, events=events, preload=True) +epochs.interpolate_bads(interp_chs=interpolate_channels) +epochs = epochs[~drop_epochs] +epochs.average().plot() + # %% # References # ---------- diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 325be7350a6..25eca7ce65a 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -813,6 +813,7 @@ def interpolate_bads( origin="auto", method=None, exclude=(), + *, verbose=None, ): """Interpolate bad MEG and EEG channels. @@ -821,41 +822,20 @@ def interpolate_bads( Parameters ---------- - reset_bads : bool - If True, remove the bads from info. - mode : str - Either ``'accurate'`` or ``'fast'``, determines the quality of the - Legendre polynomial expansion used for interpolation of channels - using the minimum-norm method. - origin : array-like, shape (3,) | str - Origin of the sphere in the head coordinate frame and in meters. - Can be ``'auto'`` (default), which means a head-digitization-based - origin fit. + %(reset_bads)s + %(mode_interp)s + %(origin_interp)s .. versionadded:: 0.17 - method : dict | str | None - Method to use for each channel type. - - - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"`` - - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"`` - - ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"`` - - None is an alias for:: - - method=dict(meg="MNE", eeg="spline", fnirs="nearest") - - If a :class:`str` is provided, the method will be applied to all channel - types supported and available in the instance. The method ``"nan"`` will - replace the channel data with ``np.nan``. + %(method_interp)s .. warning:: Be careful when using ``method="nan"``; the default value ``reset_bads=True`` may not be what you want. .. versionadded:: 0.21 - exclude : list | tuple - The channels to exclude from interpolation. If excluded a bad - channel will stay in bads. + + %(exclude_interp)s %(verbose)s Returns diff --git a/mne/epochs.py b/mne/epochs.py index 34d942536bd..bd2b0c31dab 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1410,6 +1410,83 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): self._get_data(out=False, verbose=verbose) return self + def interpolate_bads( + self, + reset_bads=True, + mode="accurate", + origin="auto", + method=None, + exclude=(), + *, + interp_chs=None, + verbose=None, + ): + """Interpolate bad channels, optionally different channels on each epoch. + + Operates in-place. + + Parameters + ---------- + %(reset_bads)s + %(mode_interp)s + %(origin_interp)s + %(method_interp)s + %(exclude_interp)s + interp_chs : list | None + The channels to interpolate on each epoch. Must be the same + length as the epochs and each list entry must be a list of channels + to interpolate for each epoch. + %(verbose)s + + Returns + ------- + epochs : instance of Epochs + The epochs with bad channels interpolated per epoch. Operates in-place. + + Notes + ----- + The ``"MNE"`` method uses minimum-norm projection to a sphere and back. + + .. versionadded:: 1.7 + """ + if interp_chs is None: + return super().interpolate_bads( + reset_bads=reset_bads, + mode=mode, + origin=origin, + method=method, + exclude=exclude, + ) + _check_preload(self, "epochs.interpolate_bads") + if len(interp_chs) != len(self): + raise ValueError( + "The length of ``interp_chs`` must be " + f"the same as the the number of epochs ({len(self)}), " + f"got {len(interp_chs)}" + ) + for i, this_interp_ch in enumerate(interp_chs): + chs_not_found = [ch for ch in this_interp_ch if ch not in self.ch_names] + if chs_not_found: + raise ValueError( + f"Channels {chs_not_found} not found " f"for interp_chs[{i}]" + ) + for epoch_idx, this_interp_chs in enumerate(interp_chs): + if this_interp_chs: + logger.debug(f"Interpolating {this_interp_chs} on epoch {epoch_idx}") + epoch = self[epoch_idx] + epoch.info["bads"] = list(this_interp_chs) + epoch.interpolate_bads( + reset_bads=epoch_idx == 0, # check once for warning + mode=mode, + origin=origin, + method=method, + exclude=exclude, + ) + self._data[epoch_idx] = epoch._data + if epoch_idx == 0: + self.info["bads"] = epoch.info["bads"] + return self + def drop_log_stats(self, ignore=("IGNORED",)): """Compute the channel stats based on a drop_log from Epochs. diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index c68fc7ce6bd..f339c8fe9ca 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -720,6 +720,41 @@ def test_reject(): assert "is a noop" in log +def test_interpolate_bads_per_epoch(): + """Test interpolating bad channels per epoch.""" + raw, events, _ = _get_data() + names = raw.ch_names[::5] + assert "MEG 2443" in names + raw.pick(names).load_data() + assert "eog" in raw + raw.info.normalize_proj() + picks = np.arange(len(raw.ch_names)) + # cull the list just to contain the relevant event + events = events[events[:, 2] == event_id, :] + assert len(events) == 7 + # test interpolate bads per epoch + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + ) + with pytest.raises(ValueError, match="length of ``interp_chs``"): + epochs.interpolate_bads(interp_chs=[]) + interp_chs = [tuple() for _ in range(len(epochs))] + interp_chs[0] = ("foo",) + with pytest.raises(ValueError, match="Channels .* not found"): + epochs.interpolate_bads(interp_chs=interp_chs) + + interp_chs[0] = (epochs.ch_names[0],) + data_before = epochs.get_data().copy() + data_after = epochs.interpolate_bads(interp_chs=interp_chs).get_data() + assert not np.array_equal(data_before[0], data_after[0]) + + def test_reject_by_annotations_reject_tmin_reject_tmax(): """Test reject_by_annotations with reject_tmin and reject_tmax defined.""" # 10 seconds of data, event at 2s, bad segment from 1s to 1.5s diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 1fa26fa16dd..27d1b1d1c58 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1168,6 +1168,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (below the nasion) and positive Y values (in front of the LPA/RPA). """ +docdict["exclude_interp"] = """ +exclude : list | tuple + The channels to exclude from interpolation. If excluded a bad + channel will stay in bads. +""" + _exclude_spectrum = """\ exclude : list of str | 'bads' Channel names to exclude{}. If ``'bads'``, channels @@ -2221,6 +2227,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): forward-backward filtering (via :func:`~scipy.signal.filtfilt`). """ +docdict["method_interp"] = """ +method : dict | str | None + Method to use for each channel type. + + - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"`` + - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"`` + - ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"`` + + None is an alias for:: + + method=dict(meg="MNE", eeg="spline", fnirs="nearest") + + If a :class:`str` is provided, the method will be applied to all channel + types supported and available in the instance. The method ``"nan"`` will + replace the channel data with ``np.nan``. +""" + docdict["method_kw_psd"] = """\ **method_kw Additional keyword arguments passed to the spectral estimation @@ -2259,6 +2282,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Extraction mode, see Notes. """ +docdict["mode_interp"] = """ +mode : str + Either ``'accurate'`` or ``'fast'``, determines the quality of the + Legendre polynomial expansion used for interpolation of channels + using the minimum-norm method. +""" + docdict["mode_pctf"] = """ mode : None | 'mean' | 'max' | 'svd' Compute summary of PSFs/CTFs across all indices specified in 'idx'. @@ -2658,6 +2688,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The default changed from False in 1.4 to True in 1.5. """ +docdict["origin_interp"] = """ +origin : array-like, shape (3,) | str + Origin of the sphere in the head coordinate frame and in meters. + Can be ``'auto'`` (default), which means a head-digitization-based + origin fit. +""" + docdict["origin_maxwell"] = """ origin : array-like, shape (3,) | str Origin of internal and external multipolar moment space in meters. @@ -3326,6 +3363,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): The resolution of the topomap image (number of pixels along each side). """ +docdict["reset_bads"] = """ +reset_bads : bool + If True, remove the bads from info. +""" + docdict["return_pca_vars_pctf"] = """ return_pca_vars : bool Whether or not to return the explained variances across the specified