Skip to content

Commit

Permalink
Merge pull request #1074 from dstl/small_improvements_08_24
Browse files Browse the repository at this point in the history
Minor Improvements August 24
  • Loading branch information
sdhiscocks authored Sep 3, 2024
2 parents 9b03203 + b27bcb1 commit 58c41c4
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 8 deletions.
9 changes: 5 additions & 4 deletions stonesoup/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,20 @@ def __init__(self, *args, **kwargs):
try:
name, _ = next(prop_iter)
except StopIteration:
raise TypeError('too many positional arguments') from None
raise TypeError(f'{cls.__name__} had too many positional arguments') from None
if name in kwargs:
raise TypeError(f'multiple values for argument {name!r}')
raise TypeError(f'{cls.__name__} received multiple values for argument {name!r}')
setattr(self, name, arg)

for name, prop in prop_iter:
value = kwargs.pop(name, prop.default)
if value is Property.empty:
raise TypeError(f'missing a required argument: {name!r}')
raise TypeError(f'{cls.__name__} is missing a required argument: {name!r}')
setattr(self, name, value)

if kwargs:
raise TypeError(f'got an unexpected keyword argument {next(iter(kwargs))!r}')
raise TypeError(f'{cls.__name__} got an unexpected keyword argument '
f'{next(iter(kwargs))!r}')

def __repr__(self):
# Indents every line
Expand Down
21 changes: 21 additions & 0 deletions stonesoup/dataassociator/tests/test_tracktotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,27 @@ def test_euclidiantracktotruth(tracks):
seconds=6)


def test_empty_track_to_truth(tracks):
associator = TrackToTruth(
association_threshold=10,
consec_pairs_confirm=3,
consec_misses_end=2)

empty_track = Track()
empty_truth = GroundTruthPath()
association_set = associator.associate_tracks(
truth_set={empty_track, tracks[0]},
tracks_set={empty_truth, tracks[2], tracks[1], tracks[3]})

associated_objects = {obj
for association in association_set
for obj in association.objects}

assert empty_track not in associated_objects
assert empty_truth not in associated_objects
assert associated_objects == {tracks[0], tracks[1]}


def test_trackidbased():
associator = TrackIDbased()
start_time = datetime.datetime(2019, 1, 1, 14, 0, 0)
Expand Down
4 changes: 4 additions & 0 deletions stonesoup/dataassociator/tracktotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ def associate_tracks(self, tracks_set: Set[Track], truth_set: Set[GroundTruthPat

associations = set()

# Remove tracks and truths with zero length
tracks_set = {track for track in tracks_set if len(track) > 0}
truth_set = {truth for truth in truth_set if len(truth) > 0}

for track in tracks_set:

current_truth = None
Expand Down
72 changes: 70 additions & 2 deletions stonesoup/functions/interpolate.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
import copy
import datetime
import warnings
from typing import Union, List, Iterable
from typing import Union, List, Iterable, Callable

import numpy as np

from ..types.array import StateVectors
from ..types.state import StateMutableSequence, State

try:
# Available from python 3.10
from itertools import pairwise
except ImportError:
try:
from more_itertools import pairwise
except ImportError:
from itertools import tee

def pairwise(iterable: Iterable):
a, b = tee(iterable)
next(b, None)
return zip(a, b)


def time_range(start_time: datetime.datetime, end_time: datetime.datetime,
timestep: datetime.timedelta = datetime.timedelta(seconds=1)) \
Expand Down Expand Up @@ -47,6 +61,24 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence,
:class:`~.StateMutableSequence` is returned with the states in the sequence corresponding to
``times``.
When interpolating the previous state is used to create the interpolated state. This means
properties from that previous state are also copied but will not be interpolated
e.g. covariance.
Parameters
----------
sms: StateMutableSequence
A :class:`~.StateMutableSequence` that should be interpolated
times: Union[datetime.datetime, List[datetime.datetime]]
a time, or a list of times for ``sms`` to be interpolated to.
Returns
-------
Union[StateMutableSequence, State]
If a single time is provided then a single state is returned. If a list of times is
provided then a :class:`~.StateMutableSequence` with the same type as ``sms`` is returned
Note
----
This function does **not** extrapolate. Times outside the range of the time range of ``sms``
Expand Down Expand Up @@ -101,6 +133,10 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence,
if len(times_to_interpolate) > 0:
# Only interpolate if required
state_vectors = StateVectors([state.state_vector for state in time_state_dict.values()])

# Needed for states with angles present
state_vectors = state_vectors.astype(float)

state_timestamps = [time.timestamp() for time in time_state_dict.keys()]
interp_timestamps = [time.timestamp() for time in times_to_interpolate]

Expand All @@ -110,9 +146,41 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence,
xp=state_timestamps,
fp=state_vectors[element_index, :])

retrieve_previous_state_fun = _get_previous_state(sms)
for state_index, time in enumerate(times_to_interpolate):
time_state_dict[time] = State(interp_output[:, state_index], timestamp=time)
original_state_before = retrieve_previous_state_fun(time)
time_state_dict[time] = original_state_before.from_state(
state=original_state_before,
timestamp=time,
state_vector=interp_output[:, state_index])

new_sms.states = [time_state_dict[time] for time in times]

return new_sms


def _get_previous_state(sms: StateMutableSequence) -> Callable[[datetime.datetime], State]:
"""This function produces a function which will return the state before a time in ``sms``.
Parameters
----------
sms: StateMutableSequence
A :class:`~.StateMutableSequence` to provide the states.
Returns
-------
Function
This function takes a :class:`datetime.datetime` and will return the State before that
time. If this inner function is called multiple times, the time must not decrease.
"""
state_iter = iter(pairwise(sms.states))
state_before, state_after = next(state_iter)

def inner_fun(t: datetime.datetime) -> State:
nonlocal state_before, state_after
while state_after.timestamp < t:
state_before, state_after = next(state_iter)
return state_before

return inner_fun
38 changes: 37 additions & 1 deletion stonesoup/functions/tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from ..interpolate import time_range, interpolate_state_mutable_sequence
from ...types.state import State, StateMutableSequence
from ...types.state import State, StateMutableSequence, GaussianState, StateVector


@pytest.mark.parametrize("input_kwargs, expected",
Expand Down Expand Up @@ -107,3 +107,39 @@ def test_interpolate_error(gen_test_data):

with pytest.raises(IndexError):
_ = interpolate_state_mutable_sequence(sms, time)


def test_interpolate_state_other_properties():
float_times = [0, 0.1, 0.4, 0.9, 1.6, 2.5, 3.6, 4.9, 6.4, 8.1, 10]

sms = StateMutableSequence([GaussianState(state_vector=[t],
covar=[[t]],
timestamp=t0+datetime.timedelta(seconds=t))
for t in float_times
])

interp_float_times = [0, 2, 4, 6, 8, 10]
interp_datetime_times = [t0+datetime.timedelta(seconds=t) for t in interp_float_times]

new_sms = interpolate_state_mutable_sequence(sms, interp_datetime_times)

# Test state vector and times
for expected_value, state in zip(interp_float_times, new_sms.states):
assert state.timestamp == t0 + datetime.timedelta(seconds=expected_value)
# assert state.state_vector[0] == expected_value
np.testing.assert_allclose(state.state_vector, StateVector([expected_value]))
assert isinstance(state, GaussianState)

# Test Covariances
# Ideally the covariance should be the same as the state vector. However interpolating the
# covariance is not implemented and probably shouldn't be. Instead the current method uses the
# previous state’s covariance. In the future it may be better to use the closest state’s
# covariance. Therefore these tests are purposely quite lax and only check the covariance
# value is within a sensible range.

assert new_sms[0].covar[0][0] == 0 # t=0
assert 1.6 <= new_sms[1].covar[0][0] <= 2.5 # t=2
assert 3.6 <= new_sms[2].covar[0][0] <= 4.9 # t=4
assert 4.9 <= new_sms[3].covar[0][0] <= 6.4 # t=6
assert 6.4 <= new_sms[4].covar[0][0] <= 8.1 # t=8
assert new_sms[5].covar[0][0] == 10 # t=10
2 changes: 1 addition & 1 deletion stonesoup/measures/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __call__(self, state1, state2):
distance measure between a pair of input :class:`~.State` objects
"""
return NotImplementedError
raise NotImplementedError


class Euclidean(Measure):
Expand Down

0 comments on commit 58c41c4

Please sign in to comment.