diff --git a/src/timescoring/visualization.py b/src/timescoring/visualization.py index 44e0816..023467d 100644 --- a/src/timescoring/visualization.py +++ b/src/timescoring/visualization.py @@ -1,6 +1,7 @@ import matplotlib.pyplot as plt from matplotlib.axes import Axes import matplotlib.colors as mc +import matplotlib.ticker as ticker import colorsys import numpy as np @@ -105,8 +106,10 @@ def plotEventScoring(ref: Annotation, hyp: Annotation, # Plot REF TP & FN for event in score.ref.events: + start_idx = int(round(event[0] * score.fs)) + end_idx = int(round(event[1] * score.fs)) # TP - if np.any(score.tpMask[round(event[0] * score.fs):round(event[1] * score.fs)]): + if np.any(score.tpMask[start_idx:end_idx]): color = 'tab:green' else: color = 'tab:purple' @@ -115,11 +118,13 @@ def plotEventScoring(ref: Annotation, hyp: Annotation, # Plot HYP TP & FP for event in score.hyp.events: + start_idx = int(round(event[0] * score.fs)) + end_idx = int(round(event[1] * score.fs)) # FP - if np.all(~score.tpMask[round(event[0] * score.fs):round(event[1] * score.fs)]): + if np.all(~score.tpMask[start_idx:end_idx]): _plotEvent([event[0], event[1] - (1 / ref.fs)], [0.5, 0.5], 'tab:red', ax) # TP - elif np.all(score.tpMask[round(event[0] * score.fs):round(event[1] * score.fs)]): + elif np.all(score.tpMask[start_idx:end_idx]): ax.plot([event[0], event[1] - (1 / ref.fs)], [0.5, 0.5], color='tab:green', linewidth=5, solid_capstyle='butt', linestyle='solid') # Mix TP, FP