Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug for ica.get_sources() and add a corresponding test #13068

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13068.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ICA getting sources for concatenated raw instances, by :newcontrib:`Beige Jin`.
BeiGeJin marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
.. _Ashley Drew: https://github.com/ashdrew
.. _Asish Panda: https://github.com/kaichogami
.. _Austin Hurst: https://github.com/a-hurst
.. _Beige Jin: https://github.com/BeiGeJin
.. _Ben Beasley: https://github.com/musicinmybrain
.. _Britta Westner: https://britta-wstnr.github.io
.. _Bruno Nicenboim: https://bnicenboim.github.io
Expand Down
6 changes: 4 additions & 2 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,8 +1295,10 @@ def _sources_as_raw(self, raw, add_channels, start, stop):
picks = pick_channels(raw.ch_names, add_channels)
data_ = np.concatenate([data_, raw.get_data(picks, start=start, stop=stop)])
out._data = data_
out._first_samps = [out.first_samp]
out._last_samps = [out.last_samp]
out_first_samp = out.first_samp
out_last_samp = out.last_samp
out._first_samps = [out_first_samp]
out._last_samps = [out_last_samp]
out.filenames = [None]
out.preload = True
out._projector = None
Expand Down
19 changes: 19 additions & 0 deletions mne/preprocessing/tests/test_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
EpochsArray,
EvokedArray,
Info,
concatenate_raws,
create_info,
make_ad_hoc_cov,
pick_channels_regexp,
Expand Down Expand Up @@ -1719,3 +1720,21 @@ def test_ica_ch_types(ch_type):
for inst in [raw, epochs, evoked]:
ica.apply(inst)
ica.get_sources(inst)


@testing.requires_testing_data
def test_ica_get_sources_concatenated():
"""Test ICA get_sources method with concatenated raws."""
# load data
raw = read_raw_fif(raw_fname).crop(0, 3).load_data() # raw has 3 seconds of data
# create concatenated raw instances
raw_concat = concatenate_raws(
[raw.copy(), raw.copy()]
) # raw_concat has 6 seconds of data
# do ICA
ica = ICA(n_components=2, max_iter=2)
with _record_warnings(), pytest.warns(UserWarning, match="did not converge"):
ica.fit(raw_concat)
# get sources
raw_sources = ica.get_sources(raw_concat) # but this only has 3 seconds of data
assert raw_concat.n_times == raw_sources.n_times # this will fail