Skip to content

Commit

Permalink
Add automatic bad detection (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
marsipu authored Aug 6, 2023
1 parent 08edfeb commit ce3bd13
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 16 deletions.
1 change: 1 addition & 0 deletions mne_pipeline_hd/extra/functions.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
;alias;target;tab;group;matplotlib;mayavi;dependencies;module;pkg_name;func_args
find_bads;Find Bad Channels;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,n_jobs
filter_data;Filter;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,filter_target,highpass,lowpass,filter_length,l_trans_bandwidth,h_trans_bandwidth,filter_method,iir_params,fir_phase,fir_window,fir_design,skip_by_annotation,fir_pad,n_jobs,enable_cuda,erm_t_limit,bad_interpolation
add_erm_ssp;Empty-Room SSP;MEEG;Compute;Preprocessing;True;False;;operations;basic;meeg,erm_ssp_duration,erm_n_grad,erm_n_mag,erm_n_eeg,n_jobs,show_plots
eeg_reference_raw;Set EEG Reference;MEEG;Compute;Preprocessing;False;False;;operations;basic;meeg,ref_channels
Expand Down
23 changes: 22 additions & 1 deletion mne_pipeline_hd/functions/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,42 @@
import mne
import mne_connectivity
import numpy as np
from mne.preprocessing import ICA
from mne.preprocessing import ICA, find_bad_channels_maxwell

from mne_pipeline_hd.pipeline.loading import MEEG
from mne_pipeline_hd.pipeline.pipeline_utils import (
check_kwargs,
compare_filep,
ismac,
iswin,
get_n_jobs,
)


# Todo: Create docstrings for each function
# =============================================================================
# PREPROCESSING AND GETTING TO EVOKED AND TFR
# =============================================================================
def find_bads(meeg, n_jobs, **kwargs):
raw = meeg.load_raw()

if raw.info["dev_head_t"] is None:
coord_frame = "meg"
else:
coord_frame = "head"

# Set number of CPU-cores to use
os.environ["OMP_NUM_THREADS"] = str(get_n_jobs(n_jobs))

noisy_chs, flat_chs = find_bad_channels_maxwell(
raw, coord_frame=coord_frame, **kwargs
)
logging.info(f"Noisy channels: {noisy_chs}\n" f"Flat channels: {flat_chs}")
raw.info["bads"] = noisy_chs + flat_chs + raw.info["bads"]
meeg.set_bad_channels(raw.info["bads"])
meeg.save_raw(raw)


def filter_data(
meeg,
filter_target,
Expand Down
7 changes: 7 additions & 0 deletions mne_pipeline_hd/gui/gui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def _html_compatible(text):
return text


# ToDo: Better with QPlainTextEdit(.appendHtml) probably for performance,
# add buffer-limit and tests
class ConsoleWidget(QTextEdit):
"""A Widget displaying formatted stdout/stderr-output"""

Expand Down Expand Up @@ -434,6 +436,7 @@ def __init__(
show_console=False,
close_directly=True,
blocking=False,
return_exception=False,
title=None,
**kwargs,
):
Expand All @@ -442,6 +445,7 @@ def __init__(
self.show_buttons = show_buttons
self.show_console = show_console
self.close_directly = close_directly
self.return_exception = return_exception
self.title = title
self.is_finished = False
self.return_value = None
Expand Down Expand Up @@ -502,6 +506,8 @@ def init_ui(self):

def on_thread_finished(self, return_value):
# Store return value to send it when user closes the dialog
if type(return_value) == ExceptionTuple and not self.return_exception:
return_value = None
self.return_value = return_value
self.is_finished = True
if self.show_buttons:
Expand Down Expand Up @@ -535,6 +541,7 @@ def closeEvent(self, event):
"Closing not possible!",
"You can't close this Dialog before this Thread finished!",
)
event.ignore()


# ToDo: WIP
Expand Down
50 changes: 40 additions & 10 deletions mne_pipeline_hd/gui/loading_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
License: BSD 3-Clause
Github: https://github.com/marsipu/mne-pipeline-hd
"""

import gc
import logging
import os
import re
import shutil
Expand Down Expand Up @@ -47,7 +48,9 @@
QWizardPage,
)
from matplotlib import pyplot as plt
from mne.preprocessing import find_bad_channels_maxwell

from mne_pipeline_hd.functions.operations import find_bads
from mne_pipeline_hd.functions.plot import (
plot_ica_components,
plot_ica_sources,
Expand Down Expand Up @@ -78,7 +81,7 @@
from mne_pipeline_hd.gui.models import AddFilesModel
from mne_pipeline_hd.gui.parameter_widgets import ComboGui
from mne_pipeline_hd.pipeline.loading import FSMRI, Group, MEEG
from mne_pipeline_hd.pipeline.pipeline_utils import compare_filep
from mne_pipeline_hd.pipeline.pipeline_utils import compare_filep, QS


def index_parser(index, all_items):
Expand Down Expand Up @@ -1062,6 +1065,22 @@ def __init__(self, main_win, mode, title):
self.setLayout(layout)


class FindBadsDialog(QDialog):
def __init__(self, parent):
super().__init__(parent)
self.pw = parent

self.values

self.init_ui()
self.open()

def init_ui(self):
layout = QVBoxLayout()

self.setLayout(layout)


class CopyBadsDialog(QDialog):
def __init__(self, parent_w):
super().__init__(parent_w)
Expand Down Expand Up @@ -1134,6 +1153,7 @@ def __init__(self, main_win):
self.ct = main_win.ct
self.pr = main_win.ct.pr
self.setWindowTitle("Assign bad_channels for your files")
self.all_files = self.pr.all_meeg + self.pr.all_erm
self.bad_chkbts = dict()
self.info_dict = dict()
self.current_obj = None
Expand Down Expand Up @@ -1164,6 +1184,10 @@ def init_ui(self):
plot_bt.clicked.connect(self.plot_raw_bad)
self.bt_layout.addWidget(plot_bt)

find_bads_bt = QPushButton("Find bads")
find_bads_bt.clicked.connect(self.find_bads)
self.bt_layout.addWidget(find_bads_bt)

copy_bt = QPushButton("Copy Bads")
copy_bt.clicked.connect(partial(CopyBadsDialog, self))
self.bt_layout.addWidget(copy_bt)
Expand Down Expand Up @@ -1253,14 +1277,6 @@ def bad_dict_selected(self, current, _):

self.make_bad_chbxs()

def _assign_bad_channels(self, bad_channels):
# Directly replace value in bad_channels_dict
# (needed for first-time assignment)
self.current_obj.pr.meeg_bad_channels[self.current_obj.name] = bad_channels
# Restore/Establish reference to direct object-attribute
self.current_obj.bad_channels = bad_channels
self.files_widget.content_changed()

def bad_ckbx_assigned(self):
bad_channels = [ch for ch in self.bad_chkbts if self.bad_chkbts[ch].isChecked()]
self.current_obj.set_bad_channels(bad_channels)
Expand Down Expand Up @@ -1297,6 +1313,20 @@ def plot_raw_bad(self):
plot_raw(self.current_obj, show_plots=True, close_func=self.get_selected_bads)
plot_dialog.close()

def find_bads(self):
wd = WorkerDialog(
self,
find_bads,
meeg=self.current_obj,
n_jobs=QS().value("n_jobs"),
show_console=True,
show_buttons=True,
close_directly=False,
return_exception=False,
title="Finding bads with maxwell filter...",
)
wd.thread_finished.connect(self.update_selection)

def resizeEvent(self, event):
if self.current_obj:
self.make_bad_chbxs()
Expand Down
6 changes: 1 addition & 5 deletions mne_pipeline_hd/pipeline/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
import matplotlib.pyplot as plt
import mne
import numpy as np

# =============================================================================
# LOADING FUNCTIONS
# =============================================================================
from tqdm import tqdm

from mne_pipeline_hd.pipeline.pipeline_utils import (
Expand Down Expand Up @@ -1413,7 +1409,7 @@ def _get_available_labels(self):
subjects_dir=self.subjects_dir,
verbose="warning",
)
except RuntimeError:
except (RuntimeError, OSError):
print(f"Parcellation {parcellation} could not be loaded!")

return labels
Expand Down
11 changes: 11 additions & 0 deletions mne_pipeline_hd/pipeline/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import json
import logging
import multiprocessing
import os
import sys
from ast import literal_eval
Expand All @@ -29,6 +30,16 @@
islin = not ismac and not iswin


def get_n_jobs(n_jobs):
"""Get the number of jobs to use for parallel processing"""
if n_jobs == -1 or n_jobs in ["auto", "max"]:
n_cores = multiprocessing.cpu_count()
else:
n_cores = int(n_jobs)

return n_cores


def encode_tuples(input_dict):
"""Encode tuples in a dictionary, because JSON does not recognize them
(CAVE: input_dict is changed in place)"""
Expand Down

0 comments on commit ce3bd13

Please sign in to comment.