diff --git a/src/pyxdf/pyxdf.py b/src/pyxdf/pyxdf.py index 4703415..79c8b07 100644 --- a/src/pyxdf/pyxdf.py +++ b/src/pyxdf/pyxdf.py @@ -334,7 +334,7 @@ def load_xdf( f.read(chunklen - 2) # Concatenate the signal across chunks - for stream in temp.values(): + for stream_id, stream in temp.items(): if stream.time_stamps: # stream with non-empty list of chunks stream.time_stamps = np.concatenate(stream.time_stamps) @@ -342,6 +342,10 @@ def load_xdf( stream.time_series = list(itertools.chain(*stream.time_series)) else: stream.time_series = np.concatenate(stream.time_series) + # Handle samples that may have arrived out-of-order, sorting + # data by ground truth timestamps if necessary. Identical + # timestamps will remain, but can be handled by dejittering. + stream = _ensure_sorted(stream_id, stream) else: # stream without any chunks stream.time_stamps = np.zeros((0,)) @@ -534,6 +538,25 @@ def _scan_forward(f): return False +def _ensure_sorted(stream_id, stream): + diffs = np.diff(stream.time_stamps) + non_strict_inc_count = np.sum(diffs <= 0) + if non_strict_inc_count > 0: + msg = " stream %d not monotonic %d sample(s) out-of-order. Sorting..." + logger.info(msg, stream_id, non_strict_inc_count) + ind = np.argsort(stream.time_stamps, kind="stable") + stream.time_stamps = stream.time_stamps[ind] + if stream.fmt == "string": + stream.time_series = np.array(stream.time_series)[ind].tolist() + else: + stream.time_series = stream.time_series[ind] + identical_timestamp_count = len(diffs) - np.count_nonzero(diffs) + if identical_timestamp_count > 0: + msg = " stream %d contains %d identical timestamp(s)." + logger.info(msg, stream_id, identical_timestamp_count) + return stream + + def _clock_sync( streams, handle_clock_resets=True, @@ -625,19 +648,24 @@ def _clock_sync( return streams -def _jitter_removal(streams, threshold_seconds=1, threshold_samples=500): +def _detect_breaks(stream, threshold_seconds=1.0, threshold_samples=500): + """Detect breaks in the time_stamps of a stream.""" + # Identify breaks in the time_stamps + diffs = np.diff(stream.time_stamps) + b_breaks = diffs > np.max((threshold_seconds, threshold_samples * stream.tdiff)) + # find indices (+ 1 to compensate for lost sample in np.diff) + break_inds = np.where(b_breaks)[0] + 1 + return break_inds + + +def _jitter_removal(streams, threshold_seconds=1.0, threshold_samples=500): for stream_id, stream in streams.items(): stream.effective_srate = 0 # will be recalculated if possible nsamples = len(stream.time_stamps) stream.segments = [] if nsamples > 0 and stream.srate > 0: - # Identify breaks in the time_stamps - diffs = np.diff(stream.time_stamps) - b_breaks = diffs > np.max( - (threshold_seconds, threshold_samples * stream.tdiff) - ) - # find indices (+ 1 to compensate for lost sample in np.diff) - break_inds = np.where(b_breaks)[0] + 1 + # find break indices + break_inds = _detect_breaks(stream, threshold_seconds, threshold_samples) # Get indices delimiting segments without breaks # 0th sample is a segment start and last sample is a segment stop diff --git a/test/test_jitter_removal.py b/test/test_jitter_removal.py new file mode 100644 index 0000000..23be9e8 --- /dev/null +++ b/test/test_jitter_removal.py @@ -0,0 +1,136 @@ +import numpy as np +from pyxdf.pyxdf import _detect_breaks, _ensure_sorted + + +class MockStreamData: + def __init__(self, time_stamps, tdiff, fmt="float32"): + self.time_stamps = np.array(time_stamps) + self.tdiff = tdiff + self.fmt = fmt + if fmt == "string": + self.time_series = [str(x) for x in time_stamps] + else: + self.time_series = np.array(time_stamps, dtype=fmt) + + +# Monotonic timeseries data. + + +def test_detect_no_breaks_seconds(): + timestamps = list(range(-5, 5)) + stream = MockStreamData(timestamps, 1) + # if diff > 2 and larger 0 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=0) + assert breaks.size == 0 + # if diff > 1 and larger 0 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=0) + assert breaks.size == 0 + + +def test_detect_no_breaks_samples(): + timestamps = list(range(-5, 5)) + stream = MockStreamData(timestamps, 1) + # if diff > 0 and larger 2 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=0, threshold_samples=2) + assert breaks.size == 0 + # if diff > 0 and larger 1 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=0, threshold_samples=1) + assert breaks.size == 0 + + +def test_detect_breaks_seconds(): + timestamps = list(range(-5, 5, 2)) + stream = MockStreamData(timestamps, 1) + # if diff > 1 and larger 0 * tdiff -> 4 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=0) + assert breaks.size == len(timestamps) - 1 + + +def test_detect_breaks_samples(): + timestamps = list(range(-5, 5, 2)) + stream = MockStreamData(timestamps, 1) + # if diff > 0 and larger 1 * tdiff -> 4 + breaks = _detect_breaks(stream, threshold_seconds=0, threshold_samples=1) + assert breaks.size == len(timestamps) - 1 + + +def test_detect_breaks_gap_in_negative(): + timestamps = [-4, 1, 2, 3, 4] + stream = MockStreamData(timestamps, 1) + # if diff > 1 and larger 1 * tdiff -> 1 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1) + assert breaks.size == 1 + assert breaks[0] == 1 + timestamps = [-4, -2, -1, 0, 1, 2, 3, 4] + stream = MockStreamData(timestamps, 1) + # if diff > 1 and larger 1 * tdiff -> 1 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1) + assert breaks.size == 1 + assert breaks[0] == 1 + # if diff > 0.1 and larger 0 * tdiff -> 7 + breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0) + assert breaks.size == len(timestamps) - 1 + + +def test_detect_breaks_gap_in_positive(): + timestamps = [1, 3, 4, 5, 6] + stream = MockStreamData(timestamps, 1) + # if diff > 1 and larger 1 * tdiff -> 1 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1) + assert breaks.size == 1 + assert breaks[0] == 1 + # if diff > 0.1 and larger 0 * tdiff -> 4 + breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0) + assert breaks.size == len(timestamps) - 1 + + +# Non-monotonic timeseries data. + + +def test_detect_breaks_reverse(): + timestamps = list(reversed(range(-5, 5))) + stream = MockStreamData(timestamps, 1) + stream = _ensure_sorted(1, stream) + # Timeseries should now also be sorted. + assert np.all(stream.time_series == sorted(timestamps)) + # if diff > 1 and larger 0 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=0) + assert breaks.size == 0 + + +def test_detect_breaks_non_monotonic_num(): + timestamps = [-4, -5, -3, -2, 0, 0, 1, 2] + stream = MockStreamData(timestamps, 1) + stream = _ensure_sorted(1, stream) + # Timeseries data should now also be sorted. + assert np.all(stream.time_series == sorted(timestamps)) + # if diff > 1 and larger 1 * tdiff -> 1 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1) + assert breaks.size == 1 + assert breaks[0] == 4 + # if diff > 2 and larger 2 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=2) + assert breaks.size == 0 + # if diff > 0.1 and larger 0 * tdiff -> 6 + breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0) + assert breaks.size == len(timestamps) - 2 + assert list(breaks) == [1, 2, 3, 4, 6, 7] + + +def test_detect_breaks_non_monotonic_str(): + timestamps = [-4, -5, -3, -2, 0, 0, 1, 2] + stream = MockStreamData(timestamps, 1, "string") + stream = _ensure_sorted(1, stream) + # Timeseries data should now also be sorted. + assert np.all(stream.time_series == [str(x) for x in sorted(timestamps)]) + # if diff > 1 and larger 1 * tdiff -> 1 + breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1) + assert breaks.size == 1 + assert breaks[0] == 4 + # if diff > 2 and larger 2 * tdiff -> 0 + breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=2) + assert breaks.size == 0 + # if diff > 0.1 and larger 0 * tdiff -> 6 + breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0) + assert breaks.size == len(timestamps) - 2 + assert list(breaks) == [1, 2, 3, 4, 6, 7]