Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, MRG] Add interpolation per epoch #12334

Closed
wants to merge 9 commits into from
1 change: 1 addition & 0 deletions doc/changes/devel/12334.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :meth:`mne.Epochs.interpolate_bads` to interpolate channels only on selected epochs for which they are bad using ``interp_chs``, by `Alex Rockhill`_
23 changes: 23 additions & 0 deletions examples/preprocessing/interpolate_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

# sphinx_gallery_thumbnail_number = 2

import numpy as np

import mne
from mne.datasets import sample

Expand All @@ -48,6 +50,27 @@
)
evoked_interp_mne.plot(exclude=[], picks=("grad", "eeg"))

# %%
# You can also interpolate bad channels per epoch
raw = mne.io.read_raw(meg_path / "sample_audvis_raw.fif")
raw.pick("eeg") # just to speed up execution
events = mne.read_events(meg_path / "sample_audvis_raw-eve.fif")
epochs = mne.Epochs(raw, events=events)

# try to only remove bad channels on some epochs to save EEG 053
epochs.drop_bad(reject=dict(eeg=100e-6))

interpolate_channels = [
entry if len(entry) < 5 else tuple() for entry in epochs.drop_log
]
drop_epochs = np.array([len(entry) >= 5 for entry in epochs.drop_log])
del epochs

epochs = mne.Epochs(raw, events=events, preload=True)
epochs.interpolate_bads(interp_chs=interpolate_channels)
epochs = epochs[~drop_epochs]
epochs.average().plot()

# %%
# References
# ----------
Expand Down
34 changes: 7 additions & 27 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ def interpolate_bads(
origin="auto",
method=None,
exclude=(),
*,
verbose=None,
):
"""Interpolate bad MEG and EEG channels.
Expand All @@ -821,41 +822,20 @@ def interpolate_bads(

Parameters
----------
reset_bads : bool
If True, remove the bads from info.
mode : str
Either ``'accurate'`` or ``'fast'``, determines the quality of the
Legendre polynomial expansion used for interpolation of channels
using the minimum-norm method.
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
origin fit.
%(reset_bads)s
%(mode_interp)s
%(origin_interp)s

.. versionadded:: 0.17
method : dict | str | None
Method to use for each channel type.

- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"``
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"``
- ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"``

None is an alias for::

method=dict(meg="MNE", eeg="spline", fnirs="nearest")

If a :class:`str` is provided, the method will be applied to all channel
types supported and available in the instance. The method ``"nan"`` will
replace the channel data with ``np.nan``.
%(method_interp)s

.. warning::
Be careful when using ``method="nan"``; the default value
``reset_bads=True`` may not be what you want.

.. versionadded:: 0.21
exclude : list | tuple
The channels to exclude from interpolation. If excluded a bad
channel will stay in bads.

%(exclude_interp)s
%(verbose)s

Returns
Expand Down
77 changes: 77 additions & 0 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,6 +1410,83 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None):
self._get_data(out=False, verbose=verbose)
return self

def interpolate_bads(
self,
reset_bads=True,
mode="accurate",
origin="auto",
method=None,
exclude=(),
*,
interp_chs=None,
verbose=None,
):
"""Interpolate bad channels, optionally different channels on each epoch.

Operates in-place.

Parameters
----------
%(reset_bads)s
%(mode_interp)s
%(origin_interp)s
%(method_interp)s
%(exclude_interp)s
interp_chs : list | None
The channels to interpolate on each epoch. Must be the same
length as the epochs and each list entry must be a list of channels
to interpolate for each epoch.
%(verbose)s

Returns
-------
epochs : instance of Epochs
The epochs with bad channels interpolated per epoch. Operates in-place.

Notes
-----
The ``"MNE"`` method uses minimum-norm projection to a sphere and back.

.. versionadded:: 1.7
"""
if interp_chs is None:
return super().interpolate_bads(
reset_bads=reset_bads,
mode=mode,
origin=origin,
method=method,
exclude=exclude,
)
_check_preload(self, "epochs.interpolate_bads")
if len(interp_chs) != len(self):
raise ValueError(
"The length of ``interp_chs`` must be "
f"the same as the the number of epochs ({len(self)}), "
f"got {len(interp_chs)}"
)
for i, this_interp_ch in enumerate(interp_chs):
chs_not_found = [ch for ch in this_interp_ch if ch not in self.ch_names]
if chs_not_found:
raise ValueError(
f"Channels {chs_not_found} not found " f"for interp_chs[{i}]"
)
for epoch_idx, this_interp_chs in enumerate(interp_chs):
if this_interp_chs:
logger.debug(f"Interpolating {this_interp_chs} on epoch {epoch_idx}")
epoch = self[epoch_idx]
epoch.info["bads"] = list(this_interp_chs)
epoch.interpolate_bads(
reset_bads=epoch_idx == 0, # check once for warning
mode=mode,
origin=origin,
method=method,
exclude=exclude,
)
self._data[epoch_idx] = epoch._data
if epoch_idx == 0:
self.info["bads"] = epoch.info["bads"]
return self

def drop_log_stats(self, ignore=("IGNORED",)):
"""Compute the channel stats based on a drop_log from Epochs.

Expand Down
35 changes: 35 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,41 @@ def test_reject():
assert "is a noop" in log


def test_interpolate_bads_per_epoch():
"""Test interpolating bad channels per epoch."""
raw, events, _ = _get_data()
names = raw.ch_names[::5]
assert "MEG 2443" in names
raw.pick(names).load_data()
assert "eog" in raw
raw.info.normalize_proj()
picks = np.arange(len(raw.ch_names))
# cull the list just to contain the relevant event
events = events[events[:, 2] == event_id, :]
assert len(events) == 7
# test interpolate bads per epoch
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks,
preload=True,
)
with pytest.raises(ValueError, match="length of ``interp_chs``"):
epochs.interpolate_bads(interp_chs=[])
interp_chs = [tuple() for _ in range(len(epochs))]
interp_chs[0] = ("foo",)
with pytest.raises(ValueError, match="Channels .* not found"):
epochs.interpolate_bads(interp_chs=interp_chs)

interp_chs[0] = (epochs.ch_names[0],)
data_before = epochs.get_data().copy()
data_after = epochs.interpolate_bads(interp_chs=interp_chs).get_data()
assert not np.array_equal(data_before[0], data_after[0])


def test_reject_by_annotations_reject_tmin_reject_tmax():
"""Test reject_by_annotations with reject_tmin and reject_tmax defined."""
# 10 seconds of data, event at 2s, bad segment from 1s to 1.5s
Expand Down
42 changes: 42 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
(below the nasion) and positive Y values (in front of the LPA/RPA).
"""

docdict["exclude_interp"] = """
exclude : list | tuple
The channels to exclude from interpolation. If excluded a bad
channel will stay in bads.
"""

_exclude_spectrum = """\
exclude : list of str | 'bads'
Channel names to exclude{}. If ``'bads'``, channels
Expand Down Expand Up @@ -2221,6 +2227,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
forward-backward filtering (via :func:`~scipy.signal.filtfilt`).
"""

docdict["method_interp"] = """
method : dict | str | None
Method to use for each channel type.

- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"``
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"``
- ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"``

None is an alias for::

method=dict(meg="MNE", eeg="spline", fnirs="nearest")

If a :class:`str` is provided, the method will be applied to all channel
types supported and available in the instance. The method ``"nan"`` will
replace the channel data with ``np.nan``.
"""

docdict["method_kw_psd"] = """\
**method_kw
Additional keyword arguments passed to the spectral estimation
Expand Down Expand Up @@ -2259,6 +2282,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
Extraction mode, see Notes.
"""

docdict["mode_interp"] = """
mode : str
Either ``'accurate'`` or ``'fast'``, determines the quality of the
Legendre polynomial expansion used for interpolation of channels
using the minimum-norm method.
"""

docdict["mode_pctf"] = """
mode : None | 'mean' | 'max' | 'svd'
Compute summary of PSFs/CTFs across all indices specified in 'idx'.
Expand Down Expand Up @@ -2658,6 +2688,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
The default changed from False in 1.4 to True in 1.5.
"""

docdict["origin_interp"] = """
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
origin fit.
"""

docdict["origin_maxwell"] = """
origin : array-like, shape (3,) | str
Origin of internal and external multipolar moment space in meters.
Expand Down Expand Up @@ -3326,6 +3363,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
The resolution of the topomap image (number of pixels along each side).
"""

docdict["reset_bads"] = """
reset_bads : bool
If True, remove the bads from info.
"""

docdict["return_pca_vars_pctf"] = """
return_pca_vars : bool
Whether or not to return the explained variances across the specified
Expand Down
Loading