diff --git a/stonesoup/tracker/base.py b/stonesoup/tracker/base.py index 8d63387b8..7308d4459 100644 --- a/stonesoup/tracker/base.py +++ b/stonesoup/tracker/base.py @@ -1,6 +1,10 @@ +import datetime from abc import abstractmethod +from typing import Iterator, Set, Tuple from ..base import Base +from ..types.detection import Detection +from ..types.track import Track class Tracker(Base): @@ -8,14 +12,14 @@ class Tracker(Base): @property @abstractmethod - def tracks(self): + def tracks(self) -> Set[Track]: raise NotImplementedError - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[datetime.datetime, Set[Track]]]: return self @abstractmethod - def __next__(self): + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: """ Returns ------- @@ -25,3 +29,48 @@ def __next__(self): Tracks existing in the time step """ raise NotImplementedError + + +class _TrackerMixInBase(Base): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.detector_iter = None + + def __iter__(self) -> Iterator[Tuple[datetime.datetime, Set[Track]]]: + if self.detector is None: + raise AttributeError("Detector has not been set. A detector attribute is required to " + "iterate over a tracker.") + if self.detector_iter is None: + self.detector_iter = iter(self.detector) + + return super().__iter__() + + +class _TrackerMixInNext(_TrackerMixInBase): + """ The tracking logic is contained within the __next__ method.""" + + @abstractmethod + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: + """Pull detections from the detector (`detector_iter`). Act on them to create tracks.""" + + def update_tracker(self, time: datetime.datetime, detections: Set[Detection]) \ + -> Tuple[datetime.datetime, Set[Track]]: + + placeholder_detector_iter = self.detector_iter + self.detector_iter = iter([(time, detections)]) + tracker_output = next(self) + self.detector_iter = placeholder_detector_iter + return tracker_output + + +class _TrackerMixInUpdate(_TrackerMixInBase): + """ The tracking logic is contained within the update_tracker function.""" + + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: + time, detections = next(self.detector_iter) + return self.update_tracker(time, detections) + + @abstractmethod + def update_tracker(self, time: datetime.datetime, detections: Set[Detection]) \ + -> Tuple[datetime.datetime, Set[Track]]: + """Use `time` and `detections` to create tracks.""" diff --git a/stonesoup/tracker/pointprocess.py b/stonesoup/tracker/pointprocess.py index 137ab5679..a3b1d9f47 100644 --- a/stonesoup/tracker/pointprocess.py +++ b/stonesoup/tracker/pointprocess.py @@ -1,16 +1,19 @@ -from .base import Tracker +import datetime +from typing import Tuple, Set + +from .base import Tracker, _TrackerMixInNext from ..base import Property +from ..hypothesiser.gaussianmixture import GaussianMixtureHypothesiser +from ..mixturereducer.gaussianmixture import GaussianMixtureReducer from ..reader import DetectionReader -from ..types.state import TaggedWeightedGaussianState from ..types.mixture import GaussianMixture from ..types.numeric import Probability +from ..types.state import TaggedWeightedGaussianState from ..types.track import Track from ..updater import Updater -from ..hypothesiser.gaussianmixture import GaussianMixtureHypothesiser -from ..mixturereducer.gaussianmixture import GaussianMixtureReducer -class PointProcessMultiTargetTracker(Tracker): +class PointProcessMultiTargetTracker(_TrackerMixInNext, Tracker): """ Base class for Gaussian Mixture (GM) style implementations of point process derived filters @@ -40,16 +43,12 @@ def __init__(self, *args, **kwargs): self.gaussian_mixture = GaussianMixture() @property - def tracks(self): + def tracks(self) -> Set[Track]: tracks = set() for track in self.target_tracks.values(): tracks.add(track) return tracks - def __iter__(self): - self.detector_iter = iter(self.detector) - return super().__iter__() - def update_tracks(self): """ Updates the tracks (:class:`Track`) associated with the filter. @@ -77,7 +76,7 @@ def update_tracks(self): self.extraction_threshold: self.target_tracks[tag] = Track([component], id=tag) - def __next__(self): + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: time, detections = next(self.detector_iter) # Add birth component self.birth_component.timestamp = time diff --git a/stonesoup/tracker/simple.py b/stonesoup/tracker/simple.py index e1fd5e123..95919f871 100644 --- a/stonesoup/tracker/simple.py +++ b/stonesoup/tracker/simple.py @@ -1,19 +1,23 @@ +import datetime +from typing import Set, Tuple + import numpy as np -from .base import Tracker +from .base import Tracker, _TrackerMixInNext from ..base import Property from ..dataassociator import DataAssociator from ..deleter import Deleter -from ..reader import DetectionReader +from ..functions import gm_reduce_single from ..initiator import Initiator -from ..updater import Updater +from ..reader import DetectionReader from ..types.array import StateVectors from ..types.prediction import GaussianStatePrediction +from ..types.track import Track from ..types.update import GaussianStateUpdate -from ..functions import gm_reduce_single +from ..updater import Updater -class SingleTargetTracker(Tracker): +class SingleTargetTracker(_TrackerMixInNext, Tracker): """A simple single target tracker. Track a single object using Stone Soup components. The tracker works by @@ -46,14 +50,10 @@ def __init__(self, *args, **kwargs): self._track = None @property - def tracks(self): + def tracks(self) -> Set[Track]: return {self._track} if self._track else set() - def __iter__(self): - self.detector_iter = iter(self.detector) - return super().__iter__() - - def __next__(self): + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: time, detections = next(self.detector_iter) if self._track is not None: associations = self.data_associator.associate( @@ -75,7 +75,7 @@ def __next__(self): return time, self.tracks -class SingleTargetMixtureTracker(Tracker): +class SingleTargetMixtureTracker(_TrackerMixInNext, Tracker): """ A simple single target tracking that receives associations from a (Gaussian) Mixture associator. @@ -104,14 +104,10 @@ def __init__(self, *args, **kwargs): self._track = None @property - def tracks(self): + def tracks(self) -> Set[Track]: return {self._track} if self._track else set() - def __iter__(self): - self.detector_iter = iter(self.detector) - return super().__iter__() - - def __next__(self): + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: time, detections = next(self.detector_iter) if self._track is not None: @@ -177,7 +173,7 @@ def __next__(self): return time, self.tracks -class MultiTargetTracker(Tracker): +class MultiTargetTracker(_TrackerMixInNext, Tracker): """A simple multi target tracker. Track multiple objects using Stone Soup components. The tracker works by @@ -203,14 +199,10 @@ def __init__(self, *args, **kwargs): self._tracks = set() @property - def tracks(self): + def tracks(self) -> Set[Track]: return self._tracks - def __iter__(self): - self.detector_iter = iter(self.detector) - return super().__iter__() - - def __next__(self): + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: time, detections = next(self.detector_iter) associations = self.data_associator.associate( @@ -231,7 +223,7 @@ def __next__(self): return time, self.tracks -class MultiTargetMixtureTracker(Tracker): +class MultiTargetMixtureTracker(_TrackerMixInNext, Tracker): """A simple multi target tracker that receives associations from a (Gaussian) Mixture associator. @@ -259,14 +251,10 @@ def __init__(self, *args, **kwargs): self._tracks = set() @property - def tracks(self): + def tracks(self) -> Set[Track]: return self._tracks - def __iter__(self): - self.detector_iter = iter(self.detector) - return super().__iter__() - - def __next__(self): + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: time, detections = next(self.detector_iter) associations = self.data_associator.associate( diff --git a/stonesoup/tracker/tests/test_base.py b/stonesoup/tracker/tests/test_base.py new file mode 100644 index 000000000..353027a04 --- /dev/null +++ b/stonesoup/tracker/tests/test_base.py @@ -0,0 +1,184 @@ +import datetime +import heapq +from typing import Tuple, Set, List + +import pytest + +from ..base import Tracker, _TrackerMixInUpdate, _TrackerMixInNext +from ...base import Property +from ...types.detection import Detection +from ...types.track import Track + + +class TrackerNextWithoutDetector(_TrackerMixInNext, Tracker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._tracks = set() + + def __next__(self) -> Tuple[datetime.datetime, Set[Track]]: + time, detections = next(self.detector_iter) + self._tracks = {Track(detection) for detection in detections} + return time, self.tracks + + @property + def tracks(self): + return self._tracks + + +class TrackerNextWithDetector(TrackerNextWithoutDetector): + detector: list = Property(default=[]) + + +class TrackerUpdateWithoutDetector(_TrackerMixInUpdate, Tracker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._tracks = set() + + def update_tracker(self, time: datetime.datetime, detections: Set[Detection]) \ + -> Tuple[datetime.datetime, Set[Track]]: + + self._tracks = {Track(detection) for detection in detections} + return time, self.tracks + + @property + def tracks(self): + return self._tracks + + +class TrackerUpdateWithDetector(TrackerUpdateWithoutDetector): + detector: list = Property(default=[]) + + +@pytest.fixture +def detector() -> List[Tuple[datetime.datetime, Set[Detection]]]: + detections = [ + Detection(timestamp=datetime.datetime(2023, 11, i), state_vector=[i]) + for i in range(1, 10) + ] + + detector = [(det.timestamp, {det}) for det in detections] + + return detector + + +@pytest.mark.parametrize("tracker_class", + [TrackerNextWithoutDetector, TrackerNextWithDetector, + TrackerUpdateWithoutDetector, TrackerUpdateWithDetector]) +def test_tracker_update_tracker(tracker_class, detector): + tracker = tracker_class() + for input_time, detections in detector: + time, tracks = tracker.update_tracker(input_time, detections) + + assert time == input_time + assert tracks == tracker.tracks + tracks_state = {track.state for track in tracks} + assert tracks_state == detections + + +@pytest.mark.parametrize("tracker_class", + [TrackerNextWithoutDetector, TrackerUpdateWithoutDetector]) +def test_tracker_without_detector_iter_error(tracker_class): + tracker_without_detector = tracker_class() + with pytest.raises(AttributeError): + iter(tracker_without_detector) + + with pytest.raises(TypeError): + next(tracker_without_detector) + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_tracker_detector_none_iter_error(tracker_class): + tracker = tracker_class(detector=None) + with pytest.raises(AttributeError): + iter(tracker) + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_tracker_with_detector_iter(tracker_class): + tracker = tracker_class() + assert iter(tracker) is tracker + assert tracker.detector_iter is not None + + with pytest.raises(StopIteration): + next(tracker) + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_tracker_with_detector_for_loop(tracker_class, detector): + tracker = tracker_class(detector=detector) + + for (tracker_time, tracks), (detect_time, detections) in zip(tracker, detector): + assert tracker_time == detect_time + assert tracks == tracker.tracks + tracks_state = {track.state for track in tracks} + assert tracks_state == detections + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_tracker_with_detector_next(tracker_class, detector): + tracker = tracker_class(detector=detector) + assert iter(tracker) is tracker + + for detect_time, detections in detector: + tracker_time, tracks = next(tracker) + assert tracker_time == detect_time + assert tracks == tracker.tracks + tracks_state = {track.state for track in tracks} + assert tracks_state == detections + + with pytest.raises(StopIteration): + _ = next(tracker) + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_tracker_wont_restart(tracker_class, detector): + tracker = tracker_class(detector=detector) + for _ in tracker: + pass + + iter(tracker) + with pytest.raises(StopIteration): + next(tracker) + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_heapq_merge_with_tracker(tracker_class, detector): + merge_output = list(heapq.merge(tracker_class(detector=detector), + tracker_class(detector=detector))) + + assert len(merge_output) == len(detector)*2 + + for idx, (tracker_time, tracks) in enumerate(merge_output): + detect_time, detections = detector[int(idx/2)] + assert tracker_time == detect_time + tracks_state = {track.state for track in tracks} + assert tracks_state == detections + + +@pytest.mark.parametrize("tracker_class", + [TrackerNextWithoutDetector, TrackerNextWithDetector, + TrackerUpdateWithoutDetector, TrackerUpdateWithDetector]) +def test_tracker_detector_iter_creation(tracker_class): + tracker_without_detector = tracker_class() + assert tracker_without_detector.detector_iter is None + + +@pytest.mark.parametrize("tracker_class", [TrackerNextWithDetector, TrackerUpdateWithDetector]) +def test_tracker_with_detections_mid_iter(tracker_class, detector): + tracker = tracker_class(detector=detector) + for i, ((tracker_time, tracks), (detect_time, detections)) in enumerate(zip(tracker, + detector)): + assert tracker_time == detect_time + assert tracks == tracker.tracks + tracks_state = {track.state for track in tracks} + assert tracks_state == detections + + interrupt_time = datetime.datetime(2024, 4, 1, i) + interrupt_detections = {Detection(timestamp=interrupt_time, state_vector=[i])} + time, interrupt_tracks = tracker.update_tracker(interrupt_time, interrupt_detections) + assert time == interrupt_time + assert interrupt_tracks == tracker.tracks + interrupt_tracks_state = {track.state for track in interrupt_tracks} + assert interrupt_tracks_state == interrupt_detections