Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support float fs for TransformerSeizureDetection (TSD) #5

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/timescoring/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__email__ = "jonathan.dan at epfl.ch"

from dataclasses import dataclass
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
from nptyping import Bool, NDArray, Shape
Expand All @@ -21,9 +21,9 @@ class Annotation:
"""
events: List[Tuple[int, int]]
mask: NDArray[Shape["Size"], Bool]
fs: int
fs: Union[int, float]

def __init__(self, data, fs: int, numSamples: int = None):
def __init__(self, data, fs:Union[int, float], numSamples: int = None):
"""Initialize an annotation instance.
- Annotation(mask, fs):
This can either be done by providing a binary vector where positive labels are
Expand All @@ -37,7 +37,7 @@ def __init__(self, data, fs: int, numSamples: int = None):

Args:
data (List[Tuple[int, int]] OR NDArray[Bool]): _description_
fs (int): Sampling frequency in Hertz of the annotations.
fs (Union[int, float]): Sampling frequency in Hertz of the annotations.
numSamples (int, optional): Is required when initalizing by providing a
list of (start, stop) tuples. It indicates the number of annotation
samples in the annotation binary mask. It should be left to None if
Expand All @@ -54,7 +54,7 @@ def __init__(self, data, fs: int, numSamples: int = None):
# Build binary mask associated with list of events
mask = np.zeros((numSamples, ), dtype=np.bool_)
for event in data:
mask[round(event[0] * fs):round(event[1] * fs)] = True
mask[int(round(event[0] * fs)):int(round(event[1] * fs))] = True
object.__setattr__(self, 'events', data) # Write to frozen object
object.__setattr__(self, 'mask', mask) # Write to frozen object

Expand Down
25 changes: 15 additions & 10 deletions src/timescoring/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import numpy as np

from .annotations import Annotation
from typing import Union


class _Scoring:
"""" Base class for different scoring methods. The class provides the common
attributes and computation of common scores based on these attributes.
"""
fs: int
fs: Union[int, float]
numSamples: int

refTrue: int
Expand Down Expand Up @@ -54,17 +55,18 @@ def computeScores(self):
class SampleScoring(_Scoring):
"""Calculates performance metrics on the sample by sample basis"""

def __init__(self, ref: Annotation, hyp: Annotation, fs: int = 1):
def __init__(self, ref: Annotation, hyp: Annotation, fs: Union[int, float] = 1):
"""Computes scores on a sample by sample basis.

Args:
ref (Annotation): Reference annotations (ground-truth)
hyp (Annotation): Hypotheses annotations (output of a ML pipeline)
fs (int): Sampling frequency of the labels. Default 1 Hz.
fs (Union[int, float]): Sampling frequency of the labels. Default 1 Hz.
"""
# Resample Data
self.ref = Annotation(ref.events, fs, round(len(ref.mask) / ref.fs * fs))
self.hyp = Annotation(hyp.events, fs, round(len(hyp.mask) / hyp.fs * fs))

self.ref = Annotation(ref.events, fs, int(round(len(ref.mask) / ref.fs * fs)))
self.hyp = Annotation(hyp.events, fs, int(round(len(hyp.mask) / hyp.fs * fs)))

if len(self.ref.mask) != len(self.hyp.mask):
raise ValueError(("The number of samples in the reference Annotation"
Expand Down Expand Up @@ -126,7 +128,7 @@ def __init__(self, ref: Annotation, hyp: Annotation, param: Parameters = Paramet
Defaults to default values.
"""
# Resample data
self.fs = 10 # Operate at a time precision of 10 Hz
self.fs = 1/12 # Operate at a time precision of 10 Hz
self.ref = Annotation(ref.events, self.fs, round(len(ref.mask) / ref.fs * self.fs))
self.hyp = Annotation(hyp.events, self.fs, round(len(hyp.mask) / hyp.fs * self.fs))

Expand All @@ -147,16 +149,19 @@ def __init__(self, ref: Annotation, hyp: Annotation, param: Parameters = Paramet
self.tpMask = np.zeros_like(self.ref.mask)
extendedRef = EventScoring._extendEvents(self.ref, param.toleranceStart, param.toleranceEnd)
for event in extendedRef.events:
relativeOverlap = (np.sum(self.hyp.mask[round(event[0] * self.fs):round(event[1] * self.fs)]) / self.fs
) / (event[1] - event[0])
start_idx = int(round(event[0] * self.fs))
end_idx = int(round(event[1] * self.fs))
relativeOverlap = (np.sum(self.hyp.mask[start_idx:end_idx]) / self.fs) / (event[1] - event[0])
if relativeOverlap > param.minOverlap + 1e-6:
self.tp += 1
self.tpMask[round(event[0] * self.fs):round(event[1] * self.fs)] = 1
self.tpMask[start_idx: end_idx] = 1

# Count False detections
self.fp = 0
for event in self.hyp.events:
if np.any(~self.tpMask[round(event[0] * self.fs):round(event[1] * self.fs)]):
start_idx = int(round(event[0] * self.fs))
end_idx = int(round(event[1] * self.fs))
if np.any(~self.tpMask[start_idx:end_idx]):
self.fp += 1

self.computeScores()
Expand Down
11 changes: 8 additions & 3 deletions src/timescoring/visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import matplotlib.colors as mc
import matplotlib.ticker as ticker
import colorsys
import numpy as np

Expand Down Expand Up @@ -105,8 +106,10 @@ def plotEventScoring(ref: Annotation, hyp: Annotation,

# Plot REF TP & FN
for event in score.ref.events:
start_idx = int(round(event[0] * score.fs))
end_idx = int(round(event[1] * score.fs))
# TP
if np.any(score.tpMask[round(event[0] * score.fs):round(event[1] * score.fs)]):
if np.any(score.tpMask[start_idx:end_idx]):
color = 'tab:green'
else:
color = 'tab:purple'
Expand All @@ -115,11 +118,13 @@ def plotEventScoring(ref: Annotation, hyp: Annotation,

# Plot HYP TP & FP
for event in score.hyp.events:
start_idx = int(round(event[0] * score.fs))
end_idx = int(round(event[1] * score.fs))
# FP
if np.all(~score.tpMask[round(event[0] * score.fs):round(event[1] * score.fs)]):
if np.all(~score.tpMask[start_idx:end_idx]):
_plotEvent([event[0], event[1] - (1 / ref.fs)], [0.5, 0.5], 'tab:red', ax)
# TP
elif np.all(score.tpMask[round(event[0] * score.fs):round(event[1] * score.fs)]):
elif np.all(score.tpMask[start_idx:end_idx]):
ax.plot([event[0], event[1] - (1 / ref.fs)], [0.5, 0.5],
color='tab:green', linewidth=5, solid_capstyle='butt', linestyle='solid')
# Mix TP, FP
Expand Down