From ce3bd133677ed41a4cbb6c1f2140454aef42aff3 Mon Sep 17 00:00:00 2001 From: Martin Schulz Date: Sun, 6 Aug 2023 21:27:44 +0200 Subject: [PATCH] Add automatic bad detection (#47) --- mne_pipeline_hd/extra/functions.csv | 1 + mne_pipeline_hd/functions/operations.py | 23 +++++++++- mne_pipeline_hd/gui/gui_utils.py | 7 +++ mne_pipeline_hd/gui/loading_widgets.py | 50 +++++++++++++++++----- mne_pipeline_hd/pipeline/loading.py | 6 +-- mne_pipeline_hd/pipeline/pipeline_utils.py | 11 +++++ 6 files changed, 82 insertions(+), 16 deletions(-) diff --git a/mne_pipeline_hd/extra/functions.csv b/mne_pipeline_hd/extra/functions.csv index 1cf7a748..aa696478 100644 --- a/mne_pipeline_hd/extra/functions.csv +++ b/mne_pipeline_hd/extra/functions.csv @@ -1,4 +1,5 @@ ;alias;target;tab;group;matplotlib;mayavi;dependencies;module;pkg_name;func_args +find_bads;Find Bad Channels;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,n_jobs filter_data;Filter;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,filter_target,highpass,lowpass,filter_length,l_trans_bandwidth,h_trans_bandwidth,filter_method,iir_params,fir_phase,fir_window,fir_design,skip_by_annotation,fir_pad,n_jobs,enable_cuda,erm_t_limit,bad_interpolation add_erm_ssp;Empty-Room SSP;MEEG;Compute;Preprocessing;True;False;;operations;basic;meeg,erm_ssp_duration,erm_n_grad,erm_n_mag,erm_n_eeg,n_jobs,show_plots eeg_reference_raw;Set EEG Reference;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,ref_channels diff --git a/mne_pipeline_hd/functions/operations.py b/mne_pipeline_hd/functions/operations.py index 716e2c2d..577ad197 100644 --- a/mne_pipeline_hd/functions/operations.py +++ b/mne_pipeline_hd/functions/operations.py @@ -24,7 +24,7 @@ import mne import mne_connectivity import numpy as np -from mne.preprocessing import ICA +from mne.preprocessing import ICA, find_bad_channels_maxwell from mne_pipeline_hd.pipeline.loading import MEEG from mne_pipeline_hd.pipeline.pipeline_utils import ( @@ -32,6 +32,7 @@ compare_filep, ismac, iswin, + get_n_jobs, ) @@ -39,6 +40,26 @@ # ============================================================================= # PREPROCESSING AND GETTING TO EVOKED AND TFR # ============================================================================= +def find_bads(meeg, n_jobs, **kwargs): + raw = meeg.load_raw() + + if raw.info["dev_head_t"] is None: + coord_frame = "meg" + else: + coord_frame = "head" + + # Set number of CPU-cores to use + os.environ["OMP_NUM_THREADS"] = str(get_n_jobs(n_jobs)) + + noisy_chs, flat_chs = find_bad_channels_maxwell( + raw, coord_frame=coord_frame, **kwargs + ) + logging.info(f"Noisy channels: {noisy_chs}\n" f"Flat channels: {flat_chs}") + raw.info["bads"] = noisy_chs + flat_chs + raw.info["bads"] + meeg.set_bad_channels(raw.info["bads"]) + meeg.save_raw(raw) + + def filter_data( meeg, filter_target, diff --git a/mne_pipeline_hd/gui/gui_utils.py b/mne_pipeline_hd/gui/gui_utils.py index 1fe05198..322cea78 100644 --- a/mne_pipeline_hd/gui/gui_utils.py +++ b/mne_pipeline_hd/gui/gui_utils.py @@ -229,6 +229,8 @@ def _html_compatible(text): return text +# ToDo: Better with QPlainTextEdit(.appendHtml) probably for performance, +# add buffer-limit and tests class ConsoleWidget(QTextEdit): """A Widget displaying formatted stdout/stderr-output""" @@ -434,6 +436,7 @@ def __init__( show_console=False, close_directly=True, blocking=False, + return_exception=False, title=None, **kwargs, ): @@ -442,6 +445,7 @@ def __init__( self.show_buttons = show_buttons self.show_console = show_console self.close_directly = close_directly + self.return_exception = return_exception self.title = title self.is_finished = False self.return_value = None @@ -502,6 +506,8 @@ def init_ui(self): def on_thread_finished(self, return_value): # Store return value to send it when user closes the dialog + if type(return_value) == ExceptionTuple and not self.return_exception: + return_value = None self.return_value = return_value self.is_finished = True if self.show_buttons: @@ -535,6 +541,7 @@ def closeEvent(self, event): "Closing not possible!", "You can't close this Dialog before this Thread finished!", ) + event.ignore() # ToDo: WIP diff --git a/mne_pipeline_hd/gui/loading_widgets.py b/mne_pipeline_hd/gui/loading_widgets.py index 642c0b67..fe4503c3 100644 --- a/mne_pipeline_hd/gui/loading_widgets.py +++ b/mne_pipeline_hd/gui/loading_widgets.py @@ -4,7 +4,8 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ - +import gc +import logging import os import re import shutil @@ -47,7 +48,9 @@ QWizardPage, ) from matplotlib import pyplot as plt +from mne.preprocessing import find_bad_channels_maxwell +from mne_pipeline_hd.functions.operations import find_bads from mne_pipeline_hd.functions.plot import ( plot_ica_components, plot_ica_sources, @@ -78,7 +81,7 @@ from mne_pipeline_hd.gui.models import AddFilesModel from mne_pipeline_hd.gui.parameter_widgets import ComboGui from mne_pipeline_hd.pipeline.loading import FSMRI, Group, MEEG -from mne_pipeline_hd.pipeline.pipeline_utils import compare_filep +from mne_pipeline_hd.pipeline.pipeline_utils import compare_filep, QS def index_parser(index, all_items): @@ -1062,6 +1065,22 @@ def __init__(self, main_win, mode, title): self.setLayout(layout) +class FindBadsDialog(QDialog): + def __init__(self, parent): + super().__init__(parent) + self.pw = parent + + self.values + + self.init_ui() + self.open() + + def init_ui(self): + layout = QVBoxLayout() + + self.setLayout(layout) + + class CopyBadsDialog(QDialog): def __init__(self, parent_w): super().__init__(parent_w) @@ -1134,6 +1153,7 @@ def __init__(self, main_win): self.ct = main_win.ct self.pr = main_win.ct.pr self.setWindowTitle("Assign bad_channels for your files") + self.all_files = self.pr.all_meeg + self.pr.all_erm self.bad_chkbts = dict() self.info_dict = dict() self.current_obj = None @@ -1164,6 +1184,10 @@ def init_ui(self): plot_bt.clicked.connect(self.plot_raw_bad) self.bt_layout.addWidget(plot_bt) + find_bads_bt = QPushButton("Find bads") + find_bads_bt.clicked.connect(self.find_bads) + self.bt_layout.addWidget(find_bads_bt) + copy_bt = QPushButton("Copy Bads") copy_bt.clicked.connect(partial(CopyBadsDialog, self)) self.bt_layout.addWidget(copy_bt) @@ -1253,14 +1277,6 @@ def bad_dict_selected(self, current, _): self.make_bad_chbxs() - def _assign_bad_channels(self, bad_channels): - # Directly replace value in bad_channels_dict - # (needed for first-time assignment) - self.current_obj.pr.meeg_bad_channels[self.current_obj.name] = bad_channels - # Restore/Establish reference to direct object-attribute - self.current_obj.bad_channels = bad_channels - self.files_widget.content_changed() - def bad_ckbx_assigned(self): bad_channels = [ch for ch in self.bad_chkbts if self.bad_chkbts[ch].isChecked()] self.current_obj.set_bad_channels(bad_channels) @@ -1297,6 +1313,20 @@ def plot_raw_bad(self): plot_raw(self.current_obj, show_plots=True, close_func=self.get_selected_bads) plot_dialog.close() + def find_bads(self): + wd = WorkerDialog( + self, + find_bads, + meeg=self.current_obj, + n_jobs=QS().value("n_jobs"), + show_console=True, + show_buttons=True, + close_directly=False, + return_exception=False, + title="Finding bads with maxwell filter...", + ) + wd.thread_finished.connect(self.update_selection) + def resizeEvent(self, event): if self.current_obj: self.make_bad_chbxs() diff --git a/mne_pipeline_hd/pipeline/loading.py b/mne_pipeline_hd/pipeline/loading.py index 2ac088a9..a3c1db96 100644 --- a/mne_pipeline_hd/pipeline/loading.py +++ b/mne_pipeline_hd/pipeline/loading.py @@ -23,10 +23,6 @@ import matplotlib.pyplot as plt import mne import numpy as np - -# ============================================================================= -# LOADING FUNCTIONS -# ============================================================================= from tqdm import tqdm from mne_pipeline_hd.pipeline.pipeline_utils import ( @@ -1413,7 +1409,7 @@ def _get_available_labels(self): subjects_dir=self.subjects_dir, verbose="warning", ) - except RuntimeError: + except (RuntimeError, OSError): print(f"Parcellation {parcellation} could not be loaded!") return labels diff --git a/mne_pipeline_hd/pipeline/pipeline_utils.py b/mne_pipeline_hd/pipeline/pipeline_utils.py index 9f7095e9..16e458ae 100644 --- a/mne_pipeline_hd/pipeline/pipeline_utils.py +++ b/mne_pipeline_hd/pipeline/pipeline_utils.py @@ -8,6 +8,7 @@ import inspect import json import logging +import multiprocessing import os import sys from ast import literal_eval @@ -29,6 +30,16 @@ islin = not ismac and not iswin +def get_n_jobs(n_jobs): + """Get the number of jobs to use for parallel processing""" + if n_jobs == -1 or n_jobs in ["auto", "max"]: + n_cores = multiprocessing.cpu_count() + else: + n_cores = int(n_jobs) + + return n_cores + + def encode_tuples(input_dict): """Encode tuples in a dictionary, because JSON does not recognize them (CAVE: input_dict is changed in place)"""