Skip to content

Commit

Permalink
Added normalization, plotting, and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
locook03 committed Sep 21, 2023
1 parent c022425 commit 9b58130
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 54 deletions.
102 changes: 68 additions & 34 deletions mne_hfo/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,40 @@
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mutual_info_score, cohen_kappa_score
from sklearn.preprocessing import Normalizer, MinMaxScaler

from mne_hfo import match_detected_annotations

implemented_labeling = ["raw-detections", "simple-binning", "overlap-predictions"]
implemented_comparisons = ["cohen-kappa", "mutual-info", "similarity-ratio"]

def compare_chart(det_list: list,
out_file,
out_file = None,
normalize = True,
**comp_kw):
"""
Compares similarity between detector results.
Creates a plot of the comparison values in a len(det_list) x len(det_list) plot.
The detectors should be fit to the same data.
Parameters
----------
det_list : List
A list containing Detector instances. Detectors should already been fit
to the data.
out_file : String (Default: None)
The file to write the chart to. If none, plot but do not save.
normalize : Bool (Default: True)
The method to use for comparison. Either 'cohen-kappa' or 'mutual-info'
**comp_kw
All other keywords are passed to compare_detectors().
Returns
-------
fig : matplotlib.pyplot.Figure
Figure containing detector comparison values.
"""

chart_size = (len(det_list), len(det_list))
Expand All @@ -29,9 +52,17 @@ def compare_chart(det_list: list,

comparison_values = np.reshape(comparison_values, chart_size)

if normalize:
transformer = MinMaxScaler().fit(comparison_values)
minMaxVals = transformer.fit_transform(comparison_values)
transformer = Normalizer().fit(minMaxVals)
norm_vals = transformer.fit_transform(minMaxVals)
comparison_values = norm_vals.copy()


print("Plotting......make take a while")
fig, ax = plt.subplots()
ax.imshow(comparison_values, cmap='hot')
im = ax.imshow(comparison_values, cmap='inferno')
ax.set_xticks(np.arange(len(det_list)), labels=[det.__class__() for det in det_list])
ax.set_yticks(np.arange(len(det_list)), labels=[det.__class__() for det in det_list])
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
Expand All @@ -40,58 +71,73 @@ def compare_chart(det_list: list,
# Loop over data dimensions and create text annotations.
for i in range(len(det_list)):
for j in range(len(det_list)):
if round(float(comparison_values[i, j]),3) > 0.5:
color = 'k'
else:
color = 'w'
text = ax.text(j, i, round(float(comparison_values[i, j]),3),
ha="center", va="center", color="w")
ha="center", va="center", color=color)

cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Similarity (normalized)", rotation=-90, va="bottom")


ax.set_title("Detector Comparison")
fig.tight_layout()
plt.show()
if out_file == None:
plt.show()
else:
plt.savefig(out_file)

return fig



def compare_detectors(clf_1, clf_2,
**comp_kw):
**kwargs):
"""
Compare fits for two classifiers per channel.
Comparisons should be symmetrical.
Parameters
----------
clf_1: Detector
clf_1 : Detector
Detector that contains detections from calling detector.fit()
clf_2: Detector
clf_2 : Detector
Detector that contains detections from calling detector.fit()
method: str
The method to use for comparison. Either 'cohen-kappa' or 'mutual-info'
**kwargs
label_method : String, default : 'overlap-predictions'
Implemented labeling method
comp_method : String, default : 'mutual-info'
Implemented comparison method
bin_width : Int, default : 1
Bin width if labeling requires a bin
Returns
-------
ch_compares: dict
ch_compares : dict
Map of channel name to metric value.
"""

if 'label_method' in comp_kw.keys():
label_method = comp_kw['label_method']
if 'label_method' in kwargs:
label_method = kwargs['label_method']
else:
label_method = 'overlap-predictions'

if 'comp_method' in comp_kw.keys():
comp_method = comp_kw['comp_method']
if 'comp_method' in kwargs:
comp_method = kwargs['comp_method']
else:
comp_method = 'mutual-info'

if 'bin_width' in comp_kw.keys():
bin_width = comp_kw['bin_width']
if 'bin_width' in kwargs:
bin_width = kwargs['bin_width']
else:
bin_width = 1

if 'normalize' in comp_kw.keys():
normalize = comp_kw['normalize']
else:
normalize = False

if not hasattr(clf_1, 'hfo_annotations_'):
raise RuntimeError("clf_1 must be fit to data before using compare")
if not hasattr(clf_2, 'hfo_annotations_'):
Expand Down Expand Up @@ -142,18 +188,6 @@ def compare_detectors(clf_1, clf_2,

ch_compares = comp(df1_labels, df2_labels, ch_names)

if normalize:
ch_vals = list(ch_compares.values())
norm = np.linalg.norm([val for val in ch_vals if not math.isnan(val)])
ch_vals = ch_vals / norm

ch_compares_norm = {}
for n, key in enumerate(ch_compares.keys()):
ch_compares_norm[key] = ch_vals[n]

print(ch_compares)
print(ch_compares_norm)

return ch_compares


Expand Down
Loading

0 comments on commit 9b58130

Please sign in to comment.