From d9e7031c401b67b595fc44fd53e4a49a2b4a7435 Mon Sep 17 00:00:00 2001 From: G Webb Date: Tue, 20 Aug 2024 10:55:44 +0100 Subject: [PATCH 1/5] Fixed edge case bug in TrackToTruth when empty tracks or truths cause an exception to be raised. A test has also been added. Changes to be committed: modified: stonesoup/dataassociator/tests/test_tracktotrack.py modified: stonesoup/dataassociator/tracktotrack.py --- .../dataassociator/tests/test_tracktotrack.py | 21 +++++++++++++++++++ stonesoup/dataassociator/tracktotrack.py | 4 ++++ 2 files changed, 25 insertions(+) diff --git a/stonesoup/dataassociator/tests/test_tracktotrack.py b/stonesoup/dataassociator/tests/test_tracktotrack.py index ba161fae0..144e2fa1c 100644 --- a/stonesoup/dataassociator/tests/test_tracktotrack.py +++ b/stonesoup/dataassociator/tests/test_tracktotrack.py @@ -164,6 +164,27 @@ def test_euclidiantracktotruth(tracks): seconds=6) +def test_empty_track_to_truth(tracks): + associator = TrackToTruth( + association_threshold=10, + consec_pairs_confirm=3, + consec_misses_end=2) + + empty_track = Track() + empty_truth = GroundTruthPath() + association_set = associator.associate_tracks( + truth_set={empty_track, tracks[0]}, + tracks_set={empty_truth, tracks[2], tracks[1], tracks[3]}) + + associated_objects = {obj + for association in association_set + for obj in association.objects} + + assert empty_track not in associated_objects + assert empty_truth not in associated_objects + assert associated_objects == {tracks[0], tracks[1]} + + def test_trackidbased(): associator = TrackIDbased() start_time = datetime.datetime(2019, 1, 1, 14, 0, 0) diff --git a/stonesoup/dataassociator/tracktotrack.py b/stonesoup/dataassociator/tracktotrack.py index 660f812fa..b8e6f58c3 100644 --- a/stonesoup/dataassociator/tracktotrack.py +++ b/stonesoup/dataassociator/tracktotrack.py @@ -258,6 +258,10 @@ def associate_tracks(self, tracks_set: Set[Track], truth_set: Set[GroundTruthPat associations = set() + # Remove tracks and truths with zero length + tracks_set = {track for track in tracks_set if len(track) > 0} + truth_set = {truth for truth in truth_set if len(truth) > 0} + for track in tracks_set: current_truth = None From 7b4a273788e71a551785c59f0d1a9a7f7dece934 Mon Sep 17 00:00:00 2001 From: G Webb Date: Tue, 20 Aug 2024 11:05:09 +0100 Subject: [PATCH 2/5] Minor bug fix in stonesoup/measures/state.py --- stonesoup/measures/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/measures/state.py b/stonesoup/measures/state.py index 04b5e3b73..01dfad408 100644 --- a/stonesoup/measures/state.py +++ b/stonesoup/measures/state.py @@ -53,7 +53,7 @@ def __call__(self, state1, state2): distance measure between a pair of input :class:`~.State` objects """ - return NotImplementedError + raise NotImplementedError class Euclidean(Measure): From 2895d6e30e59e6cd5cde18aa64bfdd29afc579ce Mon Sep 17 00:00:00 2001 From: G Webb Date: Tue, 20 Aug 2024 11:07:10 +0100 Subject: [PATCH 3/5] Add more explanation to TypeErrors in Base (stonesoup/base.py) --- stonesoup/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/stonesoup/base.py b/stonesoup/base.py index 779518bf6..eddc42807 100644 --- a/stonesoup/base.py +++ b/stonesoup/base.py @@ -425,19 +425,20 @@ def __init__(self, *args, **kwargs): try: name, _ = next(prop_iter) except StopIteration: - raise TypeError('too many positional arguments') from None + raise TypeError(f'{cls.__name__} had too many positional arguments') from None if name in kwargs: - raise TypeError(f'multiple values for argument {name!r}') + raise TypeError(f'{cls.__name__} received multiple values for argument {name!r}') setattr(self, name, arg) for name, prop in prop_iter: value = kwargs.pop(name, prop.default) if value is Property.empty: - raise TypeError(f'missing a required argument: {name!r}') + raise TypeError(f'{cls.__name__} is missing a required argument: {name!r}') setattr(self, name, value) if kwargs: - raise TypeError(f'got an unexpected keyword argument {next(iter(kwargs))!r}') + raise TypeError(f'{cls.__name__} got an unexpected keyword argument ' + f'{next(iter(kwargs))!r}') def __repr__(self): # Indents every line From 8ff7a8662cb7e6af3353fd8ec266ca409844384b Mon Sep 17 00:00:00 2001 From: G Webb Date: Tue, 20 Aug 2024 12:42:15 +0100 Subject: [PATCH 4/5] Updated the function interpolate_state_mutable_sequence to use the previous state rather than a new plain State object. Changes to be committed: modified: stonesoup/functions/interpolate.py modified: stonesoup/functions/tests/test_interpolate.py --- stonesoup/functions/interpolate.py | 71 ++++++++++++++++++- stonesoup/functions/tests/test_interpolate.py | 38 +++++++++- 2 files changed, 106 insertions(+), 3 deletions(-) diff --git a/stonesoup/functions/interpolate.py b/stonesoup/functions/interpolate.py index 18f408171..955d77b09 100644 --- a/stonesoup/functions/interpolate.py +++ b/stonesoup/functions/interpolate.py @@ -1,13 +1,26 @@ import copy import datetime import warnings -from typing import Union, List, Iterable +from typing import Union, List, Iterable, Callable import numpy as np from ..types.array import StateVectors from ..types.state import StateMutableSequence, State +try: + # Available from python 3.10 + from itertools import pairwise +except ImportError: + try: + from more_itertools import pairwise + except ImportError: + from itertools import tee + def pairwise(iterable: Iterable): + a, b = tee(iterable) + next(b, None) + return zip(a, b) + def time_range(start_time: datetime.datetime, end_time: datetime.datetime, timestep: datetime.timedelta = datetime.timedelta(seconds=1)) \ @@ -47,6 +60,24 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, :class:`~.StateMutableSequence` is returned with the states in the sequence corresponding to ``times``. + When interpolating the previous state is used to create the interpolated state. This means + properties from that previous state are also copied but will not be interpolated + e.g. covariance. + + + Parameters + ---------- + sms: StateMutableSequence + A :class:`~.StateMutableSequence` that should be interpolated + times: Union[datetime.datetime, List[datetime.datetime]] + a time, or a list of times for ``sms`` to be interpolated to. + + Returns + ------- + Union[StateMutableSequence, State] + If a single time is provided then a single state is returned. If a list of times is + provided then a :class:`~.StateMutableSequence` with the same type as ``sms`` is returned + Note ---- This function does **not** extrapolate. Times outside the range of the time range of ``sms`` @@ -101,6 +132,10 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, if len(times_to_interpolate) > 0: # Only interpolate if required state_vectors = StateVectors([state.state_vector for state in time_state_dict.values()]) + + # Needed for states with angles present + state_vectors = state_vectors.astype(float) + state_timestamps = [time.timestamp() for time in time_state_dict.keys()] interp_timestamps = [time.timestamp() for time in times_to_interpolate] @@ -110,9 +145,41 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, xp=state_timestamps, fp=state_vectors[element_index, :]) + retrieve_previous_state_fun = _get_previous_state(sms) for state_index, time in enumerate(times_to_interpolate): - time_state_dict[time] = State(interp_output[:, state_index], timestamp=time) + original_state_before = retrieve_previous_state_fun(time) + time_state_dict[time] = original_state_before.from_state( + state=original_state_before, + timestamp=time, + state_vector=interp_output[:, state_index]) new_sms.states = [time_state_dict[time] for time in times] return new_sms + + +def _get_previous_state(sms: StateMutableSequence) -> Callable[[datetime.datetime], State]: + """This function produces a function which will return the state before a time in ``sms``. + + Parameters + ---------- + sms: StateMutableSequence + A :class:`~.StateMutableSequence` to provide the states. + + Returns + ------- + Function + This function takes a :class:`datetime.datetime` and will return the State before that + time. If this inner function is called multiple times, the time must not decrease. + + """ + state_iter = iter(pairwise(sms.states)) + state_before, state_after = next(state_iter) + + def inner_fun(t: datetime.datetime) -> State: + nonlocal state_before, state_after + while state_after.timestamp < t: + state_before, state_after = next(state_iter) + return state_before + + return inner_fun diff --git a/stonesoup/functions/tests/test_interpolate.py b/stonesoup/functions/tests/test_interpolate.py index a825f4a06..64ff4434b 100644 --- a/stonesoup/functions/tests/test_interpolate.py +++ b/stonesoup/functions/tests/test_interpolate.py @@ -5,7 +5,7 @@ import pytest from ..interpolate import time_range, interpolate_state_mutable_sequence -from ...types.state import State, StateMutableSequence +from ...types.state import State, StateMutableSequence, GaussianState, StateVector @pytest.mark.parametrize("input_kwargs, expected", @@ -107,3 +107,39 @@ def test_interpolate_error(gen_test_data): with pytest.raises(IndexError): _ = interpolate_state_mutable_sequence(sms, time) + + +def test_interpolate_state_other_properties(): + float_times = [0, 0.1, 0.4, 0.9, 1.6, 2.5, 3.6, 4.9, 6.4, 8.1, 10] + + sms = StateMutableSequence([GaussianState(state_vector=[t], + covar=[[t]], + timestamp=t0+datetime.timedelta(seconds=t)) + for t in float_times + ]) + + interp_float_times = [0, 2, 4, 6, 8, 10] + interp_datetime_times = [t0+datetime.timedelta(seconds=t) for t in interp_float_times] + + new_sms = interpolate_state_mutable_sequence(sms, interp_datetime_times) + + # Test state vector and times + for expected_value, state in zip(interp_float_times, new_sms.states): + assert state.timestamp == t0 + datetime.timedelta(seconds=expected_value) + # assert state.state_vector[0] == expected_value + np.testing.assert_allclose(state.state_vector, StateVector([expected_value])) + assert isinstance(state, GaussianState) + + # Test Covariances + # Ideally the covariance should be the same as the state vector. However interpolating the + # covariance is not implemented and probably shouldn't be. Instead the current method uses the + # previous state’s covariance. In the future it may be better to use the closest state’s + # covariance. Therefore these tests are purposely quite lax and only check the covariance + # value is within a sensible range. + + assert new_sms[0].covar[0][0] == 0 # t=0 + assert 1.6 <= new_sms[1].covar[0][0] <= 2.5 # t=2 + assert 3.6 <= new_sms[2].covar[0][0] <= 4.9 # t=4 + assert 4.9 <= new_sms[3].covar[0][0] <= 6.4 # t=6 + assert 6.4 <= new_sms[4].covar[0][0] <= 8.1 # t=8 + assert new_sms[5].covar[0][0] == 10 # t=10 From b27bcb140903dd0486986e65245bb6b0b72bf1b3 Mon Sep 17 00:00:00 2001 From: G Webb Date: Tue, 20 Aug 2024 13:49:55 +0100 Subject: [PATCH 5/5] fix flake8 error --- stonesoup/functions/interpolate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stonesoup/functions/interpolate.py b/stonesoup/functions/interpolate.py index 955d77b09..8367e83f2 100644 --- a/stonesoup/functions/interpolate.py +++ b/stonesoup/functions/interpolate.py @@ -16,6 +16,7 @@ from more_itertools import pairwise except ImportError: from itertools import tee + def pairwise(iterable: Iterable): a, b = tee(iterable) next(b, None)