-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add MEGnet to make MNE-ICALabel work on MEG data #207
Open
colehank
wants to merge
33
commits into
mne-tools:main
Choose a base branch
from
colehank:megnet
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 13 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
ec28e4f
add megnet
colehank 6f272b1
add megnet
colehank b3433c8
double check
colehank 34c2f31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5b7dc9c
bug fix
colehank 96ed02d
Merge branch 'megnet' of https://github.com/colehank/mne-icalabel int…
colehank 989cb40
topomaps plot modify & bug fix
colehank bc64aa2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8af24b1
bug fix
colehank 8f5e0e6
bug fix
colehank 5aabe3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] bd7f8cc
bug fix
colehank 59aedfb
bug fix
colehank 067849c
Merge branch 'main' into megnet
colehank a0da5ee
bug fix
colehank 143df13
:q!Merge branch 'megnet' of https://github.com/colehank/mne-icalabel …
colehank 58a719a
more validation of raw obejct
colehank b89c864
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8465017
bug fix
colehank a0e526d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 40b1074
bug fix
colehank 49f39d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4f2d43a
fix model path discovery and include assets in package
mscheltienne bbce3cc
improve docstrings
mscheltienne 19e0260
simplify and test validation of line noise
mscheltienne c47d582
clean-up utils
mscheltienne 686fda1
test scripts update
colehank c5897e1
Merge branch 'main' into megnet
colehank 88d0619
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7d0a7e2
Merge branch 'main' into megnet
colehank 25824fb
bug fix
colehank 036599b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 30cc94a
bug fix
colehank File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# %% | ||
import numpy as np | ||
|
||
|
||
# Conversion functions | ||
def cart2sph(x, y, z): | ||
xy = np.sqrt(x * x + y * y) | ||
r = np.sqrt(x * x + y * y + z * z) | ||
theta = np.arctan2(y, x) | ||
phi = np.arctan2(z, xy) | ||
return r, theta, phi | ||
|
||
|
||
def pol2cart(rho, phi): | ||
x = rho * np.cos(phi) | ||
y = rho * np.sin(phi) | ||
return x, y | ||
|
||
|
||
def make_head_outlines(sphere, pos, outlines, clip_origin): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you write a docstring for the function? |
||
assert isinstance(sphere, np.ndarray) | ||
x, y, _, radius = sphere | ||
del sphere | ||
|
||
ll = np.linspace(0, 2 * np.pi, 101) | ||
head_x = np.cos(ll) * radius * 1.01 + x | ||
head_y = np.sin(ll) * radius * 1.01 + y | ||
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j) | ||
dx, _ = dx.real, dx.imag | ||
|
||
outlines_dict = dict(head=(head_x, head_y)) | ||
|
||
mask_scale = 1.0 | ||
max_norm = np.linalg.norm(pos, axis=1).max() | ||
mask_scale = max(mask_scale, max_norm * 1.01 / radius) | ||
|
||
outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y) | ||
clip_radius = radius * mask_scale | ||
outlines_dict["clip_radius"] = (clip_radius,) * 2 | ||
outlines_dict["clip_origin"] = clip_origin | ||
|
||
outlines = outlines_dict | ||
|
||
return outlines |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import io | ||
|
||
import matplotlib.pyplot as plt | ||
import mne # type: ignore | ||
import numpy as np | ||
from mne.io import BaseRaw # type: ignore | ||
from mne.preprocessing import ICA # type: ignore | ||
from mne.utils import warn # type: ignore | ||
from numpy.typing import NDArray | ||
from PIL import Image | ||
from scipy import interpolate # type: ignore | ||
from scipy.spatial import ConvexHull # type: ignore | ||
|
||
from ._utils import cart2sph, pol2cart | ||
|
||
|
||
def get_megnet_features(raw: BaseRaw, ica: ICA): | ||
"""Extract time series and topomaps for each ICA component. | ||
|
||
the main work is focused on making BrainStorm-like topomaps | ||
which trained the MEGnet. | ||
|
||
Parameters | ||
---------- | ||
raw : BaseRaw | ||
The raw MEG data. The raw instance should have 250 Hz | ||
sampling frequency and more than 60 seconds. | ||
ica : ICA | ||
The ICA object containing the independent components. | ||
|
||
Returns | ||
------- | ||
time_series : np.ndarray | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The time series for each ICA component. | ||
topomaps : np.ndarray | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The topomaps for each ICA component | ||
|
||
""" | ||
if "meg" not in raw: | ||
raise RuntimeError( | ||
"Could not find MEG channels in the provided " | ||
"Raw instance. The MEGnet model was fitted on" | ||
"MEG data and is not suited for other types of channels." | ||
) | ||
|
||
if raw.times[-1] < 60: | ||
raise RuntimeError( | ||
f"The provided raw instance has {raw.times[-1]} seconds. " | ||
"MEGnet was designed to classify features extracted from " | ||
"an MEG datasetat least 60 seconds long. " | ||
) | ||
|
||
if not np.isclose(raw.info["sfreq"], 250, atol=1e-1): | ||
warn( | ||
"The provided raw instance is not sampled at 250 Hz" | ||
f"(sfreq={raw.info['sfreq']} Hz). " | ||
"MEGnet was designed to classify features extracted from" | ||
"an MEG dataset sampled at 250 Hz" | ||
"(see the 'resample()' method for raw)." | ||
"The classification performance might be negatively impacted." | ||
) | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pos_new, outlines = _get_topomaps_data(ica) | ||
topomaps = _get_topomaps(ica, pos_new, outlines) | ||
time_series = ica.get_sources(raw)._data | ||
|
||
return time_series, topomaps | ||
|
||
|
||
def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple): | ||
"""Generate head outlines and mask positions for the topomap plot.""" | ||
x, y, _, radius = sphere | ||
ll = np.linspace(0, 2 * np.pi, 101) | ||
head_x = np.cos(ll) * radius * 1.01 + x | ||
head_y = np.sin(ll) * radius * 1.01 + y | ||
|
||
mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius) | ||
clip_radius = radius * mask_scale | ||
|
||
outlines_dict = { | ||
"head": (head_x, head_y), | ||
"mask_pos": (mask_scale * head_x, mask_scale * head_y), | ||
"clip_radius": (clip_radius,) * 2, | ||
"clip_origin": clip_origin, | ||
} | ||
return outlines_dict | ||
|
||
|
||
def _get_topomaps_data(ica: ICA): | ||
"""Prepare 2D sensor positions and outlines for topomap plotting.""" | ||
mags = mne.pick_types(ica.info, meg="mag") | ||
channel_info = ica.info["chs"] | ||
loc_3d = [channel_info[i]["loc"][0:3] for i in mags] | ||
channel_locations_3d = np.array(loc_3d) | ||
|
||
# Convert to spherical and then to 2D | ||
sph_coords = np.transpose( | ||
cart2sph( | ||
channel_locations_3d[:, 0], | ||
channel_locations_3d[:, 1], | ||
channel_locations_3d[:, 2], | ||
) | ||
) | ||
TH, PHI = sph_coords[:, 1], sph_coords[:, 2] | ||
newR = 1 - PHI / np.pi * 2 | ||
channel_locations_2d = np.transpose(pol2cart(newR, TH)) | ||
|
||
# Adjust coordinates with convex hull interpolation | ||
hull = ConvexHull(channel_locations_2d) | ||
border_indices = hull.vertices | ||
Dborder = 1 / newR[border_indices] | ||
|
||
funcTh = np.hstack( | ||
[ | ||
TH[border_indices] - 2 * np.pi, | ||
TH[border_indices], | ||
TH[border_indices] + 2 * np.pi, | ||
] | ||
) | ||
funcD = np.hstack((Dborder, Dborder, Dborder)) | ||
interp_func = interpolate.interp1d(funcTh, funcD) | ||
D = interp_func(TH) | ||
|
||
adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))]) | ||
Xnew, Ynew = pol2cart(adjusted_R, TH) | ||
pos_new = np.vstack((Xnew, Ynew)).T | ||
|
||
outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0)) | ||
return pos_new, outlines | ||
|
||
|
||
def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict): | ||
"""Generate topomap images for each ICA component.""" | ||
topomaps = [] | ||
data_picks, _, _, _, _, _, _ = mne.viz.topomap._prepare_topomap_plot( | ||
ica, ch_type="mag" | ||
) | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
components = ica.get_components() | ||
|
||
for comp in range(ica.n_components_): | ||
data = components[data_picks, comp] | ||
fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black") | ||
ax = fig.add_subplot(111) | ||
mnefig, _ = mne.viz.plot_topomap( | ||
data, | ||
pos_new, | ||
sensors=False, | ||
outlines=outlines, | ||
extrapolate="head", | ||
sphere=[0, 0, 0, 1], | ||
contours=0, | ||
res=120, | ||
axes=ax, | ||
show=False, | ||
cmap="bwr", | ||
) | ||
img_buf = io.BytesIO() | ||
mnefig.figure.savefig( | ||
img_buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0 | ||
) | ||
img_buf.seek(0) | ||
rgba_image = Image.open(img_buf) | ||
rgb_image = rgba_image.convert("RGB") | ||
img_buf.close() | ||
plt.close(fig) | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
topomaps.append(np.array(rgb_image)) | ||
|
||
return np.array(topomaps) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import os.path as op | ||
|
||
import numpy as np | ||
import onnxruntime as ort | ||
from mne.io import BaseRaw | ||
from mne.preprocessing import ICA | ||
from numpy.typing import NDArray | ||
|
||
from .features import get_megnet_features | ||
|
||
|
||
def megnet_label_components( | ||
raw: BaseRaw, | ||
ica: ICA, | ||
model_path: str = op.join("assets", "network", "megnet.onnx"), | ||
) -> dict: | ||
"""Label the provided ICA components with the MEGnet neural network. | ||
|
||
Parameters | ||
---------- | ||
raw : BaseRaw | ||
The raw MEG data. | ||
ica : mne.preprocessing.ICA | ||
The ICA data. | ||
model_path : str | ||
Path to the ONNX model file. | ||
|
||
Returns | ||
------- | ||
dict | ||
Dictionary with the following keys: | ||
- 'y_pred_proba' : list of float | ||
The predicted probabilities for each component. | ||
- 'labels' : list of str | ||
The predicted labels for each component. | ||
|
||
""" | ||
time_series, topomaps = get_megnet_features(raw, ica) | ||
|
||
assert ( | ||
time_series.shape[0] == topomaps.shape[0] | ||
), "The number of time series should match the number of spatial topomaps." | ||
assert topomaps.shape[1:] == ( | ||
120, | ||
120, | ||
3, | ||
), "The topomaps should have dimensions [N, 120, 120, 3]." | ||
assert ( | ||
time_series.shape[1] >= 15000 | ||
), "The time series must be at least 15000 samples long." | ||
|
||
session = ort.InferenceSession(model_path) | ||
predictions_vote = _chunk_predicting(session, time_series, topomaps) | ||
|
||
all_labels = ["brain/other", "eye movement", "heart", "eye blink"] | ||
# megnet_labels = ['NA', 'EB', 'SA', 'CA'] | ||
result = predictions_vote[:, 0, :] | ||
labels = [all_labels[i] for i in result.argmax(axis=1)] | ||
proba = [result[i, result[i].argmax()] for i in range(result.shape[0])] | ||
|
||
return {"y_pred_proba": proba, "labels": labels} | ||
|
||
|
||
def _chunk_predicting( | ||
session: ort.InferenceSession, | ||
time_series: NDArray, | ||
spatial_maps: NDArray, | ||
chunk_len=15000, | ||
overlap_len=3750, | ||
) -> NDArray: | ||
"""MEGnet's chunk volte algorithm.""" | ||
predction_vote = [] | ||
|
||
for comp_series, comp_map in zip(time_series, spatial_maps): | ||
time_len = comp_series.shape[0] | ||
start_times = _get_chunk_start(time_len, chunk_len, overlap_len) | ||
|
||
if start_times[-1] + chunk_len <= time_len: | ||
start_times.append(time_len - chunk_len) | ||
|
||
chunk_votes = {start: 0 for start in start_times} | ||
for t in range(time_len): | ||
in_chunks = [start <= t < start + chunk_len for start in start_times] | ||
# how many chunks the time point is in | ||
num_chunks = np.sum(in_chunks) | ||
for start_time, is_in_chunk in zip(start_times, in_chunks): | ||
if is_in_chunk: | ||
chunk_votes[start_time] += 1.0 / num_chunks | ||
|
||
weighted_predictions = {} | ||
for start_time in chunk_votes.keys(): | ||
onnx_inputs = { | ||
session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype( | ||
np.float32 | ||
), | ||
session.get_inputs()[1].name: np.expand_dims( | ||
np.expand_dims(comp_series[start_time : start_time + chunk_len], 0), | ||
-1, | ||
).astype(np.float32), | ||
} | ||
prediction = session.run(None, onnx_inputs)[0] | ||
weighted_predictions[start_time] = prediction * chunk_votes[start_time] | ||
|
||
comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0) | ||
comp_prediction /= comp_prediction.sum() | ||
predction_vote.append(comp_prediction) | ||
|
||
return np.stack(predction_vote) | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _get_chunk_start( | ||
input_len: int, chunk_len: int = 15000, overlap_len: int = 3750 | ||
) -> list: | ||
"""Calculate start times for time series chunks with overlap.""" | ||
start_times = [] | ||
start_time = 0 | ||
while start_time + chunk_len <= input_len: | ||
start_times.append(start_time) | ||
start_time += chunk_len - overlap_len | ||
return start_times |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These functions are also defined for ICLabel, so I wonder if we can pull them out for a general utility functions related to geometry.
mne-icalabel/mne_icalabel/iclabel/_utils.py
Line 97 in b2fc448
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pol2cart
is indeed a duplicate,cart2sph
returns the element in a different order and it's a bit annoying to change the order. We can keep code de-duplication for a future PR.