diff --git a/mne_icalabel/megnet/_utils.py b/mne_icalabel/megnet/_utils.py index be754851..9e7708dd 100644 --- a/mne_icalabel/megnet/_utils.py +++ b/mne_icalabel/megnet/_utils.py @@ -1,8 +1,8 @@ import numpy as np +from numpy.typing import NDArray -def cart2sph(x, y, z): - """Convert cartesian coordinates to spherical coordinates.""" +def _cart2sph(x, y, z): xy = np.sqrt(x * x + y * y) r = np.sqrt(x * x + y * y + z * z) theta = np.arctan2(y, x) @@ -10,35 +10,46 @@ def cart2sph(x, y, z): return r, theta, phi -def pol2cart(rho, phi): - """Convert polar coordinates to cartesian coordinates.""" - x = rho * np.cos(phi) - y = rho * np.sin(phi) - return x, y - - -def make_head_outlines(sphere, pos, outlines, clip_origin): - assert isinstance(sphere, np.ndarray) +def _make_head_outlines( + sphere: NDArray, + pos: NDArray, + clip_origin: tuple +) -> dict: + """a modified version of mne.viz.topomap._make_head_outlines. + + This function is used to generate head outlines for topomap plotting. + The difference between this function and the original one is that + head_x and head_y here are scaled by a factor of 1.01 to make topomap + fit the 120x120 pixel size. + Also, removed the ear and nose outlines for not needed in MEGnet. + + Parameters + ---------- + sphere : NDArray + The sphere parameters (x, y, z, radius). + pos : NDArray + The 2D sensor positions. + clip_origin : tuple + The origin of the clipping circle. + + Returns + ------- + dict + Dictionary containing the head outlines and mask positions. + + """ x, y, _, radius = sphere - del sphere - ll = np.linspace(0, 2 * np.pi, 101) head_x = np.cos(ll) * radius * 1.01 + x head_y = np.sin(ll) * radius * 1.01 + y - dx = np.exp(np.arccos(np.deg2rad(12)) * 1j) - dx, _ = dx.real, dx.imag - outlines_dict = dict(head=(head_x, head_y)) - - mask_scale = 1.0 - max_norm = np.linalg.norm(pos, axis=1).max() - mask_scale = max(mask_scale, max_norm * 1.01 / radius) - - outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y) + mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius) clip_radius = radius * mask_scale - outlines_dict["clip_radius"] = (clip_radius,) * 2 - outlines_dict["clip_origin"] = clip_origin - - outlines = outlines_dict - return outlines + outlines_dict = { + "head": (head_x, head_y), + "mask_pos": (mask_scale * head_x, mask_scale * head_y), + "clip_radius": (clip_radius,) * 2, + "clip_origin": clip_origin, + } + return outlines_dict diff --git a/mne_icalabel/megnet/features.py b/mne_icalabel/megnet/features.py index d781f8a0..43f17cb3 100644 --- a/mne_icalabel/megnet/features.py +++ b/mne_icalabel/megnet/features.py @@ -6,29 +6,31 @@ from mne.io import BaseRaw from mne.preprocessing import ICA from mne.utils import _validate_type, warn +from mne_icalabel.iclabel._utils import _pol2cart from numpy.typing import NDArray from PIL import Image from scipy import interpolate from scipy.spatial import ConvexHull -from ._utils import cart2sph, pol2cart +from ._utils import _cart2sph, _make_head_outlines def get_megnet_features(raw: BaseRaw, ica: ICA): """Extract time series and topomaps for each ICA component. - MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images. Thus, we need - to replicate the 'appearance'/'look' of a BrainStorm topomap. + MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images. + Thus, we need to replicate the 'appearance'/'look' of a BrainStorm topomap. Parameters ---------- raw : Raw. - Raw MEG recording used to fit the ICA decomposition. The raw instance should be - bandpass filtered between 1 and 100 Hz and notch filtered at 50 or 60 Hz to + Raw MEG recording used to fit the ICA decomposition. + The raw instance should be bandpass filtered between + 1 and 100 Hz and notch filtered at 50 or 60 Hz to remove line noise, and downsampled to 250 Hz. ica : ICA - ICA decomposition of the provided instance. The ICA decomposition - should use the infomax method. + ICA decomposition of the provided instance. + The ICA decomposition hould use the infomax method. Returns ------- @@ -40,18 +42,20 @@ def get_megnet_features(raw: BaseRaw, ica: ICA): _validate_type(raw, BaseRaw, "raw") _validate_type(ica, ICA, "ica") if not any( - ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True) + ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types( + unique=True) ): raise RuntimeError( - "Could not find MEG channels in the provided Raw instance. The MEGnet " - "model was fitted on MEG data and is not suited for other types of " - "channels." + "Could not find MEG channels in the provided Raw instance." + "The MEGnet model was fitted on MEG data and is not" + "suited for other types of channels." ) - if n_samples := raw.get_data().shape[1] < 15000: + if (n_samples := raw.get_data().shape[1]) < 15000: raise RuntimeError( - f"The provided raw instance has {n_samples} points. MEGnet was designed to " - "classify features extracted from an MEG dataset at least 60 seconds long " - "@ 250 Hz, corresponding to at least. 15 000 samples." + f"The provided raw instance has {n_samples} points. " + "MEGnet was designed to classify features extracted " + "from an MEG dataset at least 60 seconds long @ 250 Hz," + "corresponding to at least. 15 000 samples." ) if not np.isclose(raw.info["sfreq"], 250, atol=1e-1): warn( @@ -59,54 +63,45 @@ def get_megnet_features(raw: BaseRaw, ica: ICA): f"(sfreq={raw.info['sfreq']} Hz). " "MEGnet was designed to classify features extracted from" "an MEG dataset sampled at 250 Hz " - "(see the 'resample()' method for raw). " + "(see the 'resample()' method for Raw instances). " "The classification performance might be negatively impacted." ) if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100: warn( "The provided raw instance is not filtered between 1 and 100 Hz. " - "MEGnet was designed to classify features extracted from an MEG dataset " - "bandpass filtered between 1 and 100 Hz (see the 'filter()' method for " - "Raw). The classification performance might be negatively impacted." + "MEGnet was designed to classify features extracted from an MEG " + "dataset bandpass filtered between 1 and 100 Hz" + " (see the 'filter()' method for Raw instances)." + " The classification performance might be negatively impacted." ) if _check_line_noise(raw): warn( - "Line noise detected in 50/60 Hz. MEGnet was trained on MEG data without " - "line noise. Please remove line noise before using MEGnet " - "(see the 'notch_filter()' method for Raw instances." + "Line noise detected in 50/60 Hz. MEGnet was trained on" + "MEG data without line noise. Please remove line noise" + "before using MEGnet (see the 'notch_filter()' method" + "for Raw instances)." ) if ica.method != "infomax": warn( - f"The provided ICA instance was fitted with a '{ica.method}' algorithm. " - "MEGnet was designed with infomax ICA decompositions. To use the " - "infomax algorithm, use mne.preprocessing.ICA instance with " + f"The provided ICA instance was fitted with a '{ica.method}'" + "algorithm. MEGnet was designed with infomax method." + "To use the it, set mne.preprocessing.ICA instance with " "the arguments ICA(method='infomax')." ) + if ica.n_components != 20: + warn( + f"The provided ICA instance has {ica.n_components} components. " + "MEGnet was designed with 20 components. " + "use mne.preprocessing.ICA instance with " + "the arguments ICA(n_components=20)." + ) + pos_new, outlines = _get_topomaps_data(ica) topomaps = _get_topomaps(ica, pos_new, outlines) time_series = ica.get_sources(raw).get_data() return time_series, topomaps -def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple): - """Generate head outlines and mask positions for the topomap plot.""" - x, y, _, radius = sphere - ll = np.linspace(0, 2 * np.pi, 101) - head_x = np.cos(ll) * radius * 1.01 + x - head_y = np.sin(ll) * radius * 1.01 + y - - mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius) - clip_radius = radius * mask_scale - - outlines_dict = { - "head": (head_x, head_y), - "mask_pos": (mask_scale * head_x, mask_scale * head_y), - "clip_radius": (clip_radius,) * 2, - "clip_origin": clip_origin, - } - return outlines_dict - - def _get_topomaps_data(ica: ICA): """Prepare 2D sensor positions and outlines for topomap plotting.""" mags = mne.pick_types(ica.info, meg="mag") @@ -116,7 +111,7 @@ def _get_topomaps_data(ica: ICA): # Convert to spherical and then to 2D sph_coords = np.transpose( - cart2sph( + _cart2sph( channel_locations_3d[:, 0], channel_locations_3d[:, 1], channel_locations_3d[:, 2], @@ -124,7 +119,7 @@ def _get_topomaps_data(ica: ICA): ) TH, PHI = sph_coords[:, 1], sph_coords[:, 2] newR = 1 - PHI / np.pi * 2 - channel_locations_2d = np.transpose(pol2cart(newR, TH)) + channel_locations_2d = np.transpose(_pol2cart(TH, newR)) # Adjust coordinates with convex hull interpolation hull = ConvexHull(channel_locations_2d) @@ -143,10 +138,11 @@ def _get_topomaps_data(ica: ICA): D = interp_func(TH) adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))]) - Xnew, Ynew = pol2cart(adjusted_R, TH) + Xnew, Ynew = _pol2cart(TH, adjusted_R) pos_new = np.vstack((Xnew, Ynew)).T - outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0)) + outlines = _make_head_outlines( + np.array([0, 0, 0, 1]), pos_new, (0, 0)) return pos_new, outlines @@ -158,7 +154,8 @@ def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict): for comp in range(ica.n_components_): data = components[data_picks, comp] - fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black") + fig = plt.figure( + figsize=(1.3, 1.3), dpi=100, facecolor="black") ax = fig.add_subplot(111) mnefig, _ = mne.viz.plot_topomap( data, @@ -192,16 +189,22 @@ def _check_line_noise( raw: BaseRaw, *, neighbor_width: int = 4, threshold_factor: int = 10 ) -> bool: """Check if line noise is present in the MEG/EEG data.""" - if raw.info.get("line_freq", None) is None: # we don't know the line frequency + # we don't know the line frequency + if raw.info.get("line_freq", None) is None: return False # validate the primary and first harmonic frequencies nyquist_freq = raw.info["sfreq"] / 2.0 line_freqs = [raw.info["line_freq"], 2 * raw.info["line_freq"]] if any(nyquist_freq < lf for lf in line_freqs): - # not raising because if we get here, it means that someone provided a raw with - # a sampling rate extremely low (100 Hz?) and (1) either they missed all - # of the previous warnings encountered or (2) they know what they are doing. - warn("The sampling rate raw.info['sfreq'] is too low to estimate line niose.") + # not raising because if we get here, + # it means that someone provided a raw with + # a sampling rate extremely low (100 Hz?) and (1) + # either they missed all of the previous warnings + # encountered or (2) they know what they are doing. + warn( + "The sampling rate raw.info['sfreq'] is too low" + "to estimate line niose." + ) return False # compute the power spectrum and retrieve the frequencies of interest spectrum = raw.compute_psd(picks="meg", exclude="bads") diff --git a/mne_icalabel/megnet/label_components.py b/mne_icalabel/megnet/label_components.py index 98512014..e15d02cd 100644 --- a/mne_icalabel/megnet/label_components.py +++ b/mne_icalabel/megnet/label_components.py @@ -11,46 +11,41 @@ _MODEL_PATH: str = files("mne_icalabel.megnet") / "assets" / "megnet.onnx" -def megnet_label_components(raw: BaseRaw, ica: ICA) -> dict: +def megnet_label_components(raw: BaseRaw, ica: ICA) -> NDArray: """Label the provided ICA components with the MEGnet neural network. Parameters ---------- raw : Raw - Raw MEG recording used to fit the ICA decomposition. The raw instance should be - bandpass filtered between 1 and 100 Hz and notch filtered at 50 or 60 Hz to - remove line noise, and downsampled to 250 Hz. + Raw MEG recording used to fit the ICA decomposition. + The raw instance should be bandpass filtered between 1 and 100 Hz + and notch filtered at 50 or 60 Hz to remove line noise, + and downsampled to 250 Hz. ica : ICA - ICA decomposition of the provided instance. The ICA decomposition - should use the infomax method. + ICA decomposition of the provided instance. + The ICA decomposition should use the infomax method. Returns ------- - dict - Dictionary with the following keys: - - 'y_pred_proba' : list of float - The predicted probabilities for each component. - - 'labels' : list of str - The predicted labels for each component. + labels_pred_proba : numpy.ndarray of shape (n_components, n_classes) + The estimated corresponding predicted probabilities of output classes + for each independent component. Columns are ordered with + 'brain/other', 'eye movement', 'heart beat', 'eye blink', """ time_series, topomaps = get_megnet_features(raw, ica) # sanity-checks - assert time_series.shape[0] == topomaps.shape[0] # number of time-series <-> topos - assert topomaps.shape[1:] == (120, 120, 3) # topos are images of shape 120x120x3 - assert 15000 <= time_series.shape[1] # minimum time-series length + # number of time-series <-> topos + assert time_series.shape[0] == topomaps.shape[0] + # topos are images of shape 120x120x3 + assert topomaps.shape[1:] == (120, 120, 3) + # minimum time-series length + assert 15000 <= time_series.shape[1] session = ort.InferenceSession(_MODEL_PATH) - predictions_vote = _chunk_predicting(session, time_series, topomaps) - - all_labels = ["brain/other", "eye movement", "heart", "eye blink"] - # megnet_labels = ['NA', 'EB', 'SA', 'CA'] - result = predictions_vote[:, 0, :] - labels = [all_labels[i] for i in result.argmax(axis=1)] - proba = [result[i, result[i].argmax()] for i in range(result.shape[0])] - - return {"y_pred_proba": proba, "labels": labels} + labels_pred_proba = _chunk_predicting(session, time_series, topomaps) + return labels_pred_proba[:, 0, :] def _chunk_predicting( @@ -72,7 +67,9 @@ def _chunk_predicting( chunk_votes = {start: 0 for start in start_times} for t in range(time_len): - in_chunks = [start <= t < start + chunk_len for start in start_times] + in_chunks = [ + start <= t < start + chunk_len for start in start_times + ] # how many chunks the time point is in num_chunks = np.sum(in_chunks) for start_time, is_in_chunk in zip(start_times, in_chunks): @@ -82,18 +79,24 @@ def _chunk_predicting( weighted_predictions = {} for start_time in chunk_votes.keys(): onnx_inputs = { - session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype( - np.float32 - ), - session.get_inputs()[1].name: np.expand_dims( - np.expand_dims(comp_series[start_time : start_time + chunk_len], 0), + session.get_inputs()[0] + .name: np.expand_dims(comp_map, 0) + .astype(np.float32), + session.get_inputs()[1] + .name: np.expand_dims( + np.expand_dims( + comp_series[start_time: start_time + chunk_len], 0), -1, - ).astype(np.float32), + ) + .astype(np.float32), } prediction = session.run(None, onnx_inputs)[0] - weighted_predictions[start_time] = prediction * chunk_votes[start_time] + weighted_predictions[start_time] = ( + prediction * chunk_votes[start_time] + ) - comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0) + comp_prediction = np.stack( + list(weighted_predictions.values())).mean(axis=0) comp_prediction /= comp_prediction.sum() predction_vote.append(comp_prediction) diff --git a/mne_icalabel/megnet/tests/test_features.py b/mne_icalabel/megnet/tests/test_features.py index f9bce613..a8ed7a04 100644 --- a/mne_icalabel/megnet/tests/test_features.py +++ b/mne_icalabel/megnet/tests/test_features.py @@ -2,8 +2,8 @@ import pytest from mne import create_info from mne.io import RawArray - -from mne_icalabel.megnet.features import _check_line_noise +from mne.preprocessing import ICA +from mne_icalabel.megnet.features import _check_line_noise, get_megnet_features @pytest.fixture @@ -13,7 +13,8 @@ def raw_with_line_noise(): data1 = np.sin(2 * np.pi * 10 * times) + np.sin(2 * np.pi * 30 * times) data2 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 80 * times) data = np.vstack([data1, data2]) - info = create_info(ch_names=["10-30", "30-80"], sfreq=1000, ch_types="mag") + info = create_info( + ch_names=["10-30", "30-80"], sfreq=1000, ch_types="mag") return RawArray(data, info) @@ -23,10 +24,155 @@ def test_check_line_noise(raw_with_line_noise): # 50 Hz is absent from both channels raw_with_line_noise.info["line_freq"] = 50 assert not _check_line_noise(raw_with_line_noise) - # 10 and 80 Hz are present on one channel each, while 30 Hz is present on both + # 10 and 80 Hz are present on one channel each, + # while 30 Hz is present on both raw_with_line_noise.info["line_freq"] = 30 assert _check_line_noise(raw_with_line_noise) raw_with_line_noise.info["line_freq"] = 80 assert _check_line_noise(raw_with_line_noise) raw_with_line_noise.info["line_freq"] = 10 assert _check_line_noise(raw_with_line_noise) + + +def create_raw_ica( + n_channels=20, + sfreq=250, + ch_type="mag", + n_components=20, + filter_range=(1, 100), + ica_method="infomax", + ntime=None, +): + n_times = sfreq * 60 if ntime is None else ntime + data = np.random.randn(n_channels, n_times) + ch_names = [f"MEG {i+1}" for i in range(n_channels)] + + # Create valid channel loc for feature extraction + channel_locs = np.random.randn(n_channels, 3) + channel_locs[:, 0] += 0.1 + channel_locs[:, 1] += 0.1 + channel_locs[:, 2] += 0.1 + + info = create_info( + ch_names=ch_names, sfreq=sfreq, ch_types=ch_type) + for i, loc in enumerate(channel_locs): + info["chs"][i]["loc"][:3] = loc + + raw = RawArray(data, info) + raw.filter(*filter_range) + + # fastica can not converge with the current data + # so we use infomax in computation + # but set ica_method after fitting for testing + ica = ICA(n_components=n_components, method="infomax") + ica.fit(raw) + ica.method = ica_method + + return raw, ica + + +@pytest.fixture +def raw_ica_valid(): + """Raw instance with valid parameters.""" + raw, ica = create_raw_ica() + return raw, ica + + +def test_get_megnet_features(raw_ica_valid): + """test whether the function returns the correct features.""" + time_series, topomaps = get_megnet_features(*raw_ica_valid) + n_components = raw_ica_valid[1].n_components + n_times = raw_ica_valid[0].times.shape[0] + + assert time_series.shape == (n_components, n_times) + assert topomaps.shape == (n_components, 120, 120, 3) + + +@pytest.fixture +def raw_ica_invalid_channel(): + """Raw instance with invalid channel type.""" + raw, ica = create_raw_ica(ch_type="eeg") + return raw, ica + + +@pytest.fixture +def raw_ica_invalid_sfreq(): + """Raw instance with invalid sampling frequency.""" + raw, ica = create_raw_ica(sfreq=600) + return raw, ica + + +@pytest.fixture +def raw_ica_invalid_time(): + """Raw instance with invalid time points.""" + raw, ica = create_raw_ica(ntime=2500) + return raw, ica + + +@pytest.fixture +def raw_ica_invalid_filter(): + """Raw instance with invalid filter range.""" + raw, ica = create_raw_ica(filter_range=(0.1, 100)) + return raw, ica + + +@pytest.fixture +def raw_ica_invalid_ncomp(): + """Raw instance with invalid number of ICA components.""" + raw, ica = create_raw_ica(n_components=10) + return raw, ica + + +@pytest.fixture +def raw_ica_invalid_method(): + """Raw instance with invalid ICA method.""" + raw, ica = create_raw_ica(ica_method="fastica") + return raw, ica + + +def test_get_megnet_features_invalid( + raw_ica_invalid_channel, + raw_ica_invalid_time, + raw_ica_invalid_sfreq, + raw_ica_invalid_filter, + raw_ica_invalid_ncomp, + raw_ica_invalid_method, +): + """test whether the function raises the correct exceptions""" + test_cases = [ + (raw_ica_invalid_channel, RuntimeError, "Could not find MEG channels"), + ( + raw_ica_invalid_time, + RuntimeError, + "The provided raw instance has 2500 points.", + ), + ( + raw_ica_invalid_sfreq, + RuntimeWarning, + "The provided raw instance is not sampled at 250 Hz", + ), + ( + raw_ica_invalid_filter, + RuntimeWarning, + "The provided raw instance is not filtered between 1 and 100 Hz", + ), + ( + raw_ica_invalid_ncomp, + RuntimeWarning, + "The provided ICA instance has 10 components", + ), + ( + raw_ica_invalid_method, + RuntimeWarning, + "The provided ICA instance was fitted with a 'fastica' algorithm", + ), + ] + + for raw_ica_fixture, exc_type, msg in test_cases: + raw, ica = raw_ica_fixture + if exc_type == RuntimeError: + with pytest.raises(exc_type, match=msg): + get_megnet_features(raw, ica) + elif exc_type == RuntimeWarning: + with pytest.warns(exc_type, match=msg): + get_megnet_features(raw, ica) diff --git a/mne_icalabel/megnet/tests/test_label_components.py b/mne_icalabel/megnet/tests/test_label_components.py new file mode 100644 index 00000000..cd516083 --- /dev/null +++ b/mne_icalabel/megnet/tests/test_label_components.py @@ -0,0 +1,71 @@ +from unittest.mock import MagicMock + +import mne +import numpy as np +import onnxruntime as ort +import pytest +from mne_icalabel.megnet.label_components import ( + _chunk_predicting, + _get_chunk_start, + megnet_label_components, +) + + +@pytest.fixture +def raw_ica(): + sample_dir = mne.datasets.sample.data_path() + sample_fname = sample_dir / "MEG" / "sample" / "sample_audvis_raw.fif" + + raw = mne.io.read_raw_fif(sample_fname).pick("mag") + raw.load_data() + raw.resample(250) + raw.notch_filter(60) + raw.filter(1, 100) + + ica = mne.preprocessing.ICA( + n_components=20, + method="infomax", + random_state=88) + ica.fit(raw) + + return raw, ica + + +def test_megnet_label_components(raw_ica): + """test whether the function returns the correct artifact index""" + real_atrifact_idx = [0, 3, 5] # heart beat, eye movement, heart beat + prob = megnet_label_components(*raw_ica) + this_atrifact_idx = list(np.nonzero(prob.argmax(axis=1))[0]) + assert this_atrifact_idx == real_atrifact_idx + + +def test_get_chunk_start(): + """test whether the function returns the correct start times""" + input_len = 10000 + chunk_len = 3000 + overlap_len = 750 + + start_times = _get_chunk_start(input_len, chunk_len, overlap_len) + + assert len(start_times) == 4 + assert start_times == [0, 2250, 4500, 6750] + + +def test_chunk_predicting(): + """test whether MEGnet's chunk volte algorithm returns the correct shape""" + time_series = np.random.rand(5, 10000) + spatial_maps = np.random.rand(5, 120, 120, 3) + + mock_session = MagicMock(spec=ort.InferenceSession) + mock_session.run.return_value = [np.random.rand(4)] + + predictions = _chunk_predicting( + mock_session, + time_series, + spatial_maps, + chunk_len=3000, + overlap_len=750 + ) + + assert predictions.shape == (5, 4) + assert isinstance(predictions, np.ndarray)