Skip to content

Commit

Permalink
test scripts update
Browse files Browse the repository at this point in the history
  • Loading branch information
colehank committed Nov 14, 2024
1 parent c47d582 commit 686fda1
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 118 deletions.
65 changes: 38 additions & 27 deletions mne_icalabel/megnet/_utils.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,55 @@
import numpy as np
from numpy.typing import NDArray


def cart2sph(x, y, z):
"""Convert cartesian coordinates to spherical coordinates."""
def _cart2sph(x, y, z):
xy = np.sqrt(x * x + y * y)
r = np.sqrt(x * x + y * y + z * z)
theta = np.arctan2(y, x)
phi = np.arctan2(z, xy)
return r, theta, phi


def pol2cart(rho, phi):
"""Convert polar coordinates to cartesian coordinates."""
x = rho * np.cos(phi)
y = rho * np.sin(phi)
return x, y


def make_head_outlines(sphere, pos, outlines, clip_origin):
assert isinstance(sphere, np.ndarray)
def _make_head_outlines(
sphere: NDArray,
pos: NDArray,
clip_origin: tuple
) -> dict:
"""a modified version of mne.viz.topomap._make_head_outlines.
This function is used to generate head outlines for topomap plotting.
The difference between this function and the original one is that
head_x and head_y here are scaled by a factor of 1.01 to make topomap
fit the 120x120 pixel size.
Also, removed the ear and nose outlines for not needed in MEGnet.
Parameters
----------
sphere : NDArray
The sphere parameters (x, y, z, radius).
pos : NDArray
The 2D sensor positions.
clip_origin : tuple
The origin of the clipping circle.
Returns
-------
dict
Dictionary containing the head outlines and mask positions.
"""
x, y, _, radius = sphere
del sphere

ll = np.linspace(0, 2 * np.pi, 101)
head_x = np.cos(ll) * radius * 1.01 + x
head_y = np.sin(ll) * radius * 1.01 + y
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
dx, _ = dx.real, dx.imag

outlines_dict = dict(head=(head_x, head_y))

mask_scale = 1.0
max_norm = np.linalg.norm(pos, axis=1).max()
mask_scale = max(mask_scale, max_norm * 1.01 / radius)

outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y)
mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius)
clip_radius = radius * mask_scale
outlines_dict["clip_radius"] = (clip_radius,) * 2
outlines_dict["clip_origin"] = clip_origin

outlines = outlines_dict

return outlines
outlines_dict = {
"head": (head_x, head_y),
"mask_pos": (mask_scale * head_x, mask_scale * head_y),
"clip_radius": (clip_radius,) * 2,
"clip_origin": clip_origin,
}
return outlines_dict
111 changes: 57 additions & 54 deletions mne_icalabel/megnet/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,31 @@
from mne.io import BaseRaw
from mne.preprocessing import ICA
from mne.utils import _validate_type, warn
from mne_icalabel.iclabel._utils import _pol2cart
from numpy.typing import NDArray
from PIL import Image
from scipy import interpolate
from scipy.spatial import ConvexHull

from ._utils import cart2sph, pol2cart
from ._utils import _cart2sph, _make_head_outlines


def get_megnet_features(raw: BaseRaw, ica: ICA):
"""Extract time series and topomaps for each ICA component.
MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images. Thus, we need
to replicate the 'appearance'/'look' of a BrainStorm topomap.
MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images.
Thus, we need to replicate the 'appearance'/'look' of a BrainStorm topomap.
Parameters
----------
raw : Raw.
Raw MEG recording used to fit the ICA decomposition. The raw instance should be
bandpass filtered between 1 and 100 Hz and notch filtered at 50 or 60 Hz to
Raw MEG recording used to fit the ICA decomposition.
The raw instance should be bandpass filtered between
1 and 100 Hz and notch filtered at 50 or 60 Hz to
remove line noise, and downsampled to 250 Hz.
ica : ICA
ICA decomposition of the provided instance. The ICA decomposition
should use the infomax method.
ICA decomposition of the provided instance.
The ICA decomposition hould use the infomax method.
Returns
-------
Expand All @@ -40,73 +42,66 @@ def get_megnet_features(raw: BaseRaw, ica: ICA):
_validate_type(raw, BaseRaw, "raw")
_validate_type(ica, ICA, "ica")
if not any(
ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True)
ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(
unique=True)
):
raise RuntimeError(
"Could not find MEG channels in the provided Raw instance. The MEGnet "
"model was fitted on MEG data and is not suited for other types of "
"channels."
"Could not find MEG channels in the provided Raw instance."
"The MEGnet model was fitted on MEG data and is not"
"suited for other types of channels."
)
if n_samples := raw.get_data().shape[1] < 15000:
if (n_samples := raw.get_data().shape[1]) < 15000:
raise RuntimeError(
f"The provided raw instance has {n_samples} points. MEGnet was designed to "
"classify features extracted from an MEG dataset at least 60 seconds long "
"@ 250 Hz, corresponding to at least. 15 000 samples."
f"The provided raw instance has {n_samples} points. "
"MEGnet was designed to classify features extracted "
"from an MEG dataset at least 60 seconds long @ 250 Hz,"
"corresponding to at least. 15 000 samples."
)
if not np.isclose(raw.info["sfreq"], 250, atol=1e-1):
warn(
"The provided raw instance is not sampled at 250 Hz "
f"(sfreq={raw.info['sfreq']} Hz). "
"MEGnet was designed to classify features extracted from"
"an MEG dataset sampled at 250 Hz "
"(see the 'resample()' method for raw). "
"(see the 'resample()' method for Raw instances). "
"The classification performance might be negatively impacted."
)
if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100:
warn(
"The provided raw instance is not filtered between 1 and 100 Hz. "
"MEGnet was designed to classify features extracted from an MEG dataset "
"bandpass filtered between 1 and 100 Hz (see the 'filter()' method for "
"Raw). The classification performance might be negatively impacted."
"MEGnet was designed to classify features extracted from an MEG "
"dataset bandpass filtered between 1 and 100 Hz"
" (see the 'filter()' method for Raw instances)."
" The classification performance might be negatively impacted."
)
if _check_line_noise(raw):
warn(
"Line noise detected in 50/60 Hz. MEGnet was trained on MEG data without "
"line noise. Please remove line noise before using MEGnet "
"(see the 'notch_filter()' method for Raw instances."
"Line noise detected in 50/60 Hz. MEGnet was trained on"
"MEG data without line noise. Please remove line noise"
"before using MEGnet (see the 'notch_filter()' method"
"for Raw instances)."
)
if ica.method != "infomax":
warn(
f"The provided ICA instance was fitted with a '{ica.method}' algorithm. "
"MEGnet was designed with infomax ICA decompositions. To use the "
"infomax algorithm, use mne.preprocessing.ICA instance with "
f"The provided ICA instance was fitted with a '{ica.method}'"
"algorithm. MEGnet was designed with infomax method."
"To use the it, set mne.preprocessing.ICA instance with "
"the arguments ICA(method='infomax')."
)
if ica.n_components != 20:
warn(
f"The provided ICA instance has {ica.n_components} components. "
"MEGnet was designed with 20 components. "
"use mne.preprocessing.ICA instance with "
"the arguments ICA(n_components=20)."
)

pos_new, outlines = _get_topomaps_data(ica)
topomaps = _get_topomaps(ica, pos_new, outlines)
time_series = ica.get_sources(raw).get_data()
return time_series, topomaps


def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple):
"""Generate head outlines and mask positions for the topomap plot."""
x, y, _, radius = sphere
ll = np.linspace(0, 2 * np.pi, 101)
head_x = np.cos(ll) * radius * 1.01 + x
head_y = np.sin(ll) * radius * 1.01 + y

mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius)
clip_radius = radius * mask_scale

outlines_dict = {
"head": (head_x, head_y),
"mask_pos": (mask_scale * head_x, mask_scale * head_y),
"clip_radius": (clip_radius,) * 2,
"clip_origin": clip_origin,
}
return outlines_dict


def _get_topomaps_data(ica: ICA):
"""Prepare 2D sensor positions and outlines for topomap plotting."""
mags = mne.pick_types(ica.info, meg="mag")
Expand All @@ -116,15 +111,15 @@ def _get_topomaps_data(ica: ICA):

# Convert to spherical and then to 2D
sph_coords = np.transpose(
cart2sph(
_cart2sph(
channel_locations_3d[:, 0],
channel_locations_3d[:, 1],
channel_locations_3d[:, 2],
)
)
TH, PHI = sph_coords[:, 1], sph_coords[:, 2]
newR = 1 - PHI / np.pi * 2
channel_locations_2d = np.transpose(pol2cart(newR, TH))
channel_locations_2d = np.transpose(_pol2cart(TH, newR))

# Adjust coordinates with convex hull interpolation
hull = ConvexHull(channel_locations_2d)
Expand All @@ -143,10 +138,11 @@ def _get_topomaps_data(ica: ICA):
D = interp_func(TH)

adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))])
Xnew, Ynew = pol2cart(adjusted_R, TH)
Xnew, Ynew = _pol2cart(TH, adjusted_R)
pos_new = np.vstack((Xnew, Ynew)).T

outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0))
outlines = _make_head_outlines(
np.array([0, 0, 0, 1]), pos_new, (0, 0))
return pos_new, outlines


Expand All @@ -158,7 +154,8 @@ def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):

for comp in range(ica.n_components_):
data = components[data_picks, comp]
fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black")
fig = plt.figure(
figsize=(1.3, 1.3), dpi=100, facecolor="black")
ax = fig.add_subplot(111)
mnefig, _ = mne.viz.plot_topomap(
data,
Expand Down Expand Up @@ -192,16 +189,22 @@ def _check_line_noise(
raw: BaseRaw, *, neighbor_width: int = 4, threshold_factor: int = 10
) -> bool:
"""Check if line noise is present in the MEG/EEG data."""
if raw.info.get("line_freq", None) is None: # we don't know the line frequency
# we don't know the line frequency
if raw.info.get("line_freq", None) is None:
return False
# validate the primary and first harmonic frequencies
nyquist_freq = raw.info["sfreq"] / 2.0
line_freqs = [raw.info["line_freq"], 2 * raw.info["line_freq"]]
if any(nyquist_freq < lf for lf in line_freqs):
# not raising because if we get here, it means that someone provided a raw with
# a sampling rate extremely low (100 Hz?) and (1) either they missed all
# of the previous warnings encountered or (2) they know what they are doing.
warn("The sampling rate raw.info['sfreq'] is too low to estimate line niose.")
# not raising because if we get here,
# it means that someone provided a raw with
# a sampling rate extremely low (100 Hz?) and (1)
# either they missed all of the previous warnings
# encountered or (2) they know what they are doing.
warn(
"The sampling rate raw.info['sfreq'] is too low"
"to estimate line niose."
)
return False
# compute the power spectrum and retrieve the frequencies of interest
spectrum = raw.compute_psd(picks="meg", exclude="bads")
Expand Down
Loading

0 comments on commit 686fda1

Please sign in to comment.