From 9b581303bb94696a40ada27eb0f9aac217d94601 Mon Sep 17 00:00:00 2001
From: Logan Cook
Date: Wed, 20 Sep 2023 21:45:31 -0500
Subject: [PATCH] Added normalization, plotting, and tests

---
 mne_hfo/compare.py            | 102 ++++++++++++-------
 mne_hfo/tests/test_compare.py | 182 ++++++++++++++++++++++++++++++----
 2 files changed, 230 insertions(+), 54 deletions(-)

diff --git a/mne_hfo/compare.py b/mne_hfo/compare.py
index 2fb593c..b742b2f 100644
--- a/mne_hfo/compare.py
+++ b/mne_hfo/compare.py
@@ -3,6 +3,7 @@
 import pandas as pd
 import matplotlib.pyplot as plt
 from sklearn.metrics import mutual_info_score, cohen_kappa_score
+from sklearn.preprocessing import Normalizer, MinMaxScaler
 
 from mne_hfo import match_detected_annotations
 
@@ -10,10 +11,32 @@ implemented_comparisons = ["cohen-kappa", "mutual-info", "similarity-ratio"]
 
 
 def compare_chart(det_list: list,
-                  out_file,
+                  out_file=None,
+                  normalize=True,
                   **comp_kw):
     """
-
+    Compare the similarity between detector results.
+    Plots the comparison values in a len(det_list) x len(det_list) grid.
+
+
+    The detectors should be fit to the same data.
+
+    Parameters
+    ----------
+    det_list : List
+        A list containing Detector instances. Detectors should already have
+        been fit to the data.
+    out_file : String (Default: None)
+        The file to write the chart to. If None, show the plot but do not save it.
+    normalize : Bool (Default: True)
+        Whether to scale the comparison values to [0, 1] before plotting.
+    **comp_kw
+        All other keywords are passed to compare_detectors().
+
+    Returns
+    -------
+    fig : matplotlib.pyplot.Figure
+        Figure containing detector comparison values.
     """
 
     chart_size = (len(det_list), len(det_list))
@@ -29,9 +52,17 @@
 
     comparison_values = np.reshape(comparison_values, chart_size)
 
+    if normalize:
+        # Scale each column to [0, 1], then scale each row to unit norm.
+        transformer = MinMaxScaler()
+        min_max_vals = transformer.fit_transform(comparison_values)
+        norm_vals = Normalizer().fit_transform(min_max_vals)
+        comparison_values = norm_vals.copy()
+
+    print("Plotting... this may take a while")
     fig, ax = plt.subplots()
 
-    ax.imshow(comparison_values, cmap='hot')
+    im = ax.imshow(comparison_values, cmap='inferno')
     ax.set_xticks(np.arange(len(det_list)), labels=[det.__class__() for det in det_list])
     ax.set_yticks(np.arange(len(det_list)), labels=[det.__class__() for det in det_list])
     plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
@@ -40,17 +71,30 @@
     # Loop over data dimensions and create text annotations.
     for i in range(len(det_list)):
         for j in range(len(det_list)):
+            if round(float(comparison_values[i, j]), 3) > 0.5:
+                color = 'k'
+            else:
+                color = 'w'
             text = ax.text(j, i, round(float(comparison_values[i, j]),3),
-                           ha="center", va="center", color="w")
+                           ha="center", va="center", color=color)
+
+    cbar = ax.figure.colorbar(im, ax=ax)
+    cbar.ax.set_ylabel("Similarity (normalized)", rotation=-90, va="bottom")
+
     ax.set_title("Detector Comparison")
     fig.tight_layout()
-    plt.show()
+    if out_file is None:
+        plt.show()
+    else:
+        plt.savefig(out_file)
+
+    return fig
 
 
 def compare_detectors(clf_1, clf_2,
-                      **comp_kw):
+                      **kwargs):
     """
     Compare fits for two classifiers per channel.
@@ -58,40 +102,42 @@ def compare_detectors(clf_1, clf_2,
 
     Parameters
     ----------
-    clf_1: Detector
+    clf_1 : Detector
         Detector that contains detections from calling detector.fit()
-    clf_2: Detector
+    clf_2 : Detector
         Detector that contains detections from calling detector.fit()
-    method: str
-        The method to use for comparison. Either 'cohen-kappa' or 'mutual-info'
+    **kwargs
+        label_method : String, default : 'overlap-predictions'
+            Method used to convert each detector's events into labels.
+        comp_method : String, default : 'mutual-info'
+            Comparison metric to compute between the two sets of labels.
+        bin_width : Int, default : 1
+            Bin width to use when the labeling method requires binning.
+
 
     Returns
     -------
-    ch_compares: dict
+    ch_compares : dict
         Map of channel name to metric value.
     """
-    if 'label_method' in comp_kw.keys():
-        label_method = comp_kw['label_method']
+    if 'label_method' in kwargs:
+        label_method = kwargs['label_method']
     else:
         label_method = 'overlap-predictions'
-    if 'comp_method' in comp_kw.keys():
-        comp_method = comp_kw['comp_method']
+    if 'comp_method' in kwargs:
+        comp_method = kwargs['comp_method']
     else:
         comp_method = 'mutual-info'
-    if 'bin_width' in comp_kw.keys():
-        bin_width = comp_kw['bin_width']
+    if 'bin_width' in kwargs:
+        bin_width = kwargs['bin_width']
     else:
         bin_width = 1
-    if 'normalize' in comp_kw.keys():
-        normalize = comp_kw['normalize']
-    else:
-        normalize = False
-
     if not hasattr(clf_1, 'hfo_annotations_'):
         raise RuntimeError("clf_1 must be fit to data before using compare")
     if not hasattr(clf_2, 'hfo_annotations_'):
@@ -142,18 +188,6 @@
 
     ch_compares = comp(df1_labels, df2_labels, ch_names)
 
-    if normalize:
-        ch_vals = list(ch_compares.values())
-        norm = np.linalg.norm([val for val in ch_vals if not math.isnan(val)])
-        ch_vals = ch_vals / norm
-
-        ch_compares_norm = {}
-        for n, key in enumerate(ch_compares.keys()):
-            ch_compares_norm[key] = ch_vals[n]
-
-        print(ch_compares)
-        print(ch_compares_norm)
-
     return ch_compares
 
 
diff --git a/mne_hfo/tests/test_compare.py b/mne_hfo/tests/test_compare.py
index b11ff0b..641658d 100644
--- a/mne_hfo/tests/test_compare.py
+++ b/mne_hfo/tests/test_compare.py
@@ -1,10 +1,72 @@
 import pytest
+from mne import Annotations
+
 from mne_hfo import create_annotations_df
 from mne_hfo.compare import compare_detectors
 from mne_hfo.detect import RMSDetector
 from numpy.testing import assert_almost_equal
 
+@pytest.fixture(scope='function')
+def create_detector1():
+    # Create a dummy RMSDetector object.
+    rms1 = RMSDetector()
+
+    # Create dummy HFO annotations with the expected fields. We will
+    # treat these annotations as the predictions that rms1 would produce
+    # after being fit to the data.
+    sfreq = 1000
+    # create dummy reference annotations
+    onset1 = [8, 12.6, 59.9, 99.2, 150.4]
+    offset1 = [9.7300, 14.870, 66.1, 101.22, 156.1]
+    duration1 = [offset - onset for onset, offset in zip(onset1, offset1)]
+
+    hfo_annotations = []
+    description = ['hfo'] * len(onset1)
+    ch_names = [['A1'] for _ in range(len(onset1))]
+    ch_hfo_events = Annotations(onset=onset1, duration=duration1,
+                                description=description,
+                                ch_names=ch_names)
+    hfo_annotations.append(ch_hfo_events)
+    rms1.hfo_annotations_ = hfo_annotations[0]
+    rms1.sfreq = sfreq
+    rms1.ch_names = ['A1']
+
+    # Give the dummy detector the length of the data it was 'fit' on.
+    rms1.n_times = 200*1000
+
+    return rms1
+
+@pytest.fixture(scope='function')
+def create_detector2():
+    # Create a dummy RMSDetector object.
+    rms2 = RMSDetector()
+
+    # Create dummy HFO annotations with the expected fields. We will
+    # treat these annotations as the predictions that rms2 would produce
+    # after being fit to the data.
+    sfreq = 1000
+
+    # create dummy predicted HFO annotations
+    onset2 = [2, 60.1, 98.3, 110.23]
+    offset2 = [6.93, 65.6, 101.45, 112.89]
+    duration2 = [offset - onset for onset, offset in zip(onset2, offset2)]
+
+    hfo_annotations = []
+    description = ['hfo'] * len(onset2)
+    ch_names = [['A1'] for _ in range(len(onset2))]
+    ch_hfo_events = Annotations(onset=onset2, duration=duration2,
+                                description=description,
+                                ch_names=ch_names)
+    hfo_annotations.append(ch_hfo_events)
+    rms2.hfo_annotations_ = hfo_annotations[0]
+    rms2.sfreq = sfreq
+    rms2.ch_names = ['A1']
+
+    # Give the dummy detector the length of the data it was 'fit' on.
+    rms2.n_times = 200*1000
+
+    return rms2
 
 
 def test_compare_detectors():
     """Test comparison metrics."""
@@ -16,7 +78,9 @@
     # Make sure you can't run compare when Detectors haven't been fit
     with pytest.raises(RuntimeError, match='clf_1 must be fit'
                                            ' to data before using compare'):
-        compare_detectors(rms1, rms2, method="mutual-info")
+        compare_detectors(rms1, rms2,
+                          comp_method="mutual-info",
+                          label_method="overlap-predictions")
 
     # Create two event dataframes with expected columns. We will
     # consider df1 to be predictions from rms1 and df2 to be predictions
     # from rms2
@@ -26,31 +90,71 @@
     onset1 = [8, 12.6, 59.9, 99.2, 150.4]
     offset1 = [9.7300, 14.870, 66.1, 101.22, 156.1]
     duration1 = [offset - onset for onset, offset in zip(onset1, offset1)]
-    ch_name = ['A1'] * len(onset1)
-    annotation_label = ['hfo'] * len(onset1)
-    annot_df1 = create_annotations_df(onset1, duration1, ch_name,
-                                      sfreq, annotation_label)
-    annot_df1['sample'] = annot_df1['onset'] * sfreq
+
+    # Make sure you can't run compare when Detectors haven't been fit
+    with pytest.raises(RuntimeError, match='clf_1 must be fit'
+                                           ' to data before using compare'):
+        compare_detectors(rms1, rms2,
+                          comp_method="mutual-info",
+                          label_method="overlap-predictions")
+
+    hfo_annotations = []
+    description = ['hfo'] * len(onset1)
+    ch_names = [['A1'] for _ in range(len(onset1))]
+    ch_hfo_events = Annotations(onset=onset1, duration=duration1,
+                                description=description,
+                                ch_names=ch_names)
+    hfo_annotations.append(ch_hfo_events)
+    rms1.hfo_annotations_ = hfo_annotations[0]
+    rms1.sfreq = sfreq
+    rms1.ch_names = ['A1']
+
+    # Make sure you can't run compare when Detectors haven't been fit
+    with pytest.raises(RuntimeError, match='clf_2 must be fit'
+                                           ' to data before using compare'):
+        compare_detectors(rms1, rms2,
+                          comp_method="mutual-info",
+                          label_method="overlap-predictions")
 
     # create dummy predicted HFO annotations
     onset2 = [2, 60.1, 98.3, 110.23]
     offset2 = [6.93, 65.6, 101.45, 112.89]
     duration2 = [offset - onset for onset, offset in zip(onset2, offset2)]
-    ch_name = ['A1'] * len(onset2)
-    annotation_label = ['hfo'] * len(onset2)
-    annot_df2 = create_annotations_df(onset2, duration2, ch_name,
-                                      sfreq, annotation_label)
-    annot_df2['sample'] = annot_df2['onset'] * sfreq
 
-    # Attach the annotation dataframes to the dummy detectors
-    rms1.df_ = annot_df1
+    hfo_annotations = []
+    description = ['hfo'] * len(onset2)
+    ch_names = [['A1'] for _ in range(len(onset2))]
+    ch_hfo_events = Annotations(onset=onset2, duration=duration2,
+                                description=description,
+                                ch_names=ch_names)
+    hfo_annotations.append(ch_hfo_events)
+    rms2.hfo_annotations_ = hfo_annotations[0]
+    rms2.sfreq = sfreq
+    rms2.ch_names = ['A1']
+
+    # Give the dummy detector the length of the data it was 'fit' on.
+    rms1.n_times = 200*1000
 
-    # Make sure you can't run compare when Detectors haven't been fit
     with pytest.raises(RuntimeError, match='clf_2 must be fit'
-                                           ' to data before using compare'):
-        compare_detectors(rms1, rms2, method="mutual-info")
+                                           ' to data before using compare'):
+        compare_detectors(rms1, rms2,
+                          comp_method="mutual-info",
+                          label_method="overlap-predictions")
+
+    rms2.n_times = 100*1000
 
-    rms2.df_ = annot_df2
+    # Make sure the length of the raw data for each classifier is identical
+    with pytest.raises(RuntimeError, match='clf_1 and clf_2 must be fit'
+                                           ' on the same length of data'):
+        compare_detectors(rms1, rms2,
+                          comp_method="mutual-info",
+                          label_method="overlap-predictions")
+
+    rms2.n_times = 200*1000
+
+def test_comparison_methods(create_detector1, create_detector2):
+    det1 = create_detector1
+    det2 = create_detector2
 
     # We expect the labels from rms1 to be [False, True, True, True,
     # True, False, True]
     # and the labels from rms2 to be [True, False, False, True, True,
     # True, False]
     # which gives the following mutual info and kappa scores
 
     expected_mutual_info = 0.20218548540814557
     expected_kappa_score = -0.5217391304347827
+    expected_similarity = 0.28571429
 
     # Calculate mutual info and assert almost equal
-    mutual_info = compare_detectors(rms1, rms2, method="mutual-info")
+    mutual_info = compare_detectors(det1, det2, label_method="overlap-predictions", comp_method="mutual-info")
     mi = mutual_info['A1']
     assert_almost_equal(mi, expected_mutual_info, decimal=5)
 
     # Calculate kappa score and assert almost equal
-    kappa = compare_detectors(rms1, rms2, method="cohen-kappa")
+    kappa = compare_detectors(det1, det2, label_method="overlap-predictions", comp_method="cohen-kappa")
     k = kappa['A1']
     assert_almost_equal(k, expected_kappa_score, decimal=5)
 
+    similarity = compare_detectors(det1, det2, label_method="overlap-predictions", comp_method="similarity-ratio")
+    s = similarity['A1']
+    assert_almost_equal(s, expected_similarity, decimal=5)
+
+    # Make sure you can't run a random method
+    with pytest.raises(NotImplementedError):
+        compare_detectors(det1, det2, label_method="overlap-predictions", comp_method="average")
+
+
+def test_labeling_methods(create_detector1, create_detector2):
+    det1 = create_detector1
+    det2 = create_detector2
+
+    # We expect the labels from rms1 to be [False, True, True, True,
+    # True, False, True]
+    # and the labels from rms2 to be [True, False, False, True, True,
+    # True, False]
+    # which gives the following similarity scores
+
+    expected_binning_score = 0.86
+    expected_raw_detections = 0.88575
+
+    # Calculate binned similarity ratio and assert almost equal
+    binning = compare_detectors(det1, det2,
+                                label_method="simple-binning",
+                                comp_method="similarity-ratio",
+                                bin_width=1000)
+    b = binning['A1']
+    assert_almost_equal(b, expected_binning_score, decimal=2)
+
+    raw_detections = compare_detectors(det1, det2,
+                                       label_method="raw-detections",
+                                       comp_method="similarity-ratio")
+    r = raw_detections['A1']
+    assert_almost_equal(r, expected_raw_detections, decimal=5)
+
     # Make sure you can't run a random method
     with pytest.raises(NotImplementedError):
-        compare_detectors(rms1, rms2, method="average")
+        compare_detectors(det1, det2, label_method="matching", comp_method="similarity-ratio")
\ No newline at end of file
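
A minimal usage sketch of the compare API touched by this patch, assuming two detectors that have already been fit to the same recording. The `LineLengthDetector` class, the `raw` object, and the `fit()` calls are illustrative assumptions and not part of the patch itself:

    from mne_hfo.detect import RMSDetector, LineLengthDetector
    from mne_hfo.compare import compare_detectors, compare_chart

    # Assumption: `raw` holds the recording both detectors are fit on.
    det1 = RMSDetector()
    det2 = LineLengthDetector()
    det1.fit(raw)
    det2.fit(raw)

    # Per-channel similarity between the two sets of detections.
    scores = compare_detectors(det1, det2,
                               label_method="overlap-predictions",
                               comp_method="cohen-kappa")
    print(scores)  # e.g. {'A1': ..., 'A2': ...}

    # Pairwise chart for any number of fitted detectors; keywords other than
    # out_file/normalize are forwarded to compare_detectors().
    fig = compare_chart([det1, det2],
                        out_file="detector_comparison.png",  # None -> show only
                        normalize=True,
                        comp_method="mutual-info")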