Skip to content

Commit

Permalink
Merge pull request #1068 from kopytjuk/add-clear-mot-metrics
Browse files Browse the repository at this point in the history
Add CLEAR MOT `MetricGenerator`
  • Loading branch information
sdhiscocks authored Sep 10, 2024
2 parents 58c41c4 + 14b04e0 commit 3cd4a8c
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 9 deletions.
5 changes: 5 additions & 0 deletions docs/source/stonesoup.metricgenerator.clearmotmetrics.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CLEAR MOT Metrics
=================

.. automodule:: stonesoup.metricgenerator.clearmotmetrics
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/stonesoup.metricgenerator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Metric Generators
stonesoup.metricgenerator.pcrbmetric
stonesoup.metricgenerator.uncertaintymetric
stonesoup.metricgenerator.plotter
stonesoup.metricgenerator.clearmotmetrics

.. automodule:: stonesoup.metricgenerator
:no-members:
Expand Down
257 changes: 257 additions & 0 deletions stonesoup/metricgenerator/clearmotmetrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import datetime
from collections import defaultdict
from typing import Dict, List, Set, Tuple, Union

from ..base import Property
from ..measures.state import Measure
from ..types.groundtruth import GroundTruthPath
from ..types.metric import Metric, TimeRangeMetric
from ..types.state import State
from ..types.time import TimeRange
from ..types.track import Track
from .base import MetricGenerator
from .manager import MultiManager

MatchSetAtTimestamp = Set[Tuple[str, str]] # tuples of (truth, track)
StatesFromTimeIdLookup = Dict[datetime.datetime, Dict[str, State]]


class ClearMotMetrics(MetricGenerator):
"""CLEAR MOT metrics
Computes multi-object tracking (MOT) metrics designed for the classification of events,
activities, and relationships (CLEAR) evaluation workshops. The implementation here
is derived from [1] and provides following metrics:
* MOTP (precision): average distance between all associated truth and track states.
The target score is 0.
* MOTA (accuracy): 1 - ratio of the number of misses, false positives, and mismatches
(ID-switches)relative to the total number of truth states. The target score is 1.
This score can become negative with a higher number of errors.
Reference:
[1] Evaluating Multiple Object Tracking Performance: The CLEAR MOT Metrics,
Bernardin et al, 2008
"""
tracks_key: str = Property(doc='Key to access set of tracks added to MetricManager',
default='tracks')
truths_key: str = Property(doc="Key to access set of ground truths added to MetricManager. "
"Or key to access a second set of tracks for track-to-track "
"metric generation",
default='groundtruth_paths')

distance_measure: Measure = Property(
doc="Distance measure used in calculating the MOTP score.")

def compute_metric(self, manager: MultiManager, **kwargs) -> List[Metric]:
"""Compute MOTP and MOTA metrics for a given time-period covered by truths and the tracks.
Parameters
----------
manager : MetricManager
containing the data to be used to create the metric(s)
Returns
-------
: list of :class:`~.Metric` objects
Generated metrics
"""

timestamps = manager.list_timestamps(generator=self)

motp_score, mota_score = self._compute_mota_and_motp(manager)

time_range = TimeRange(min(timestamps), max(timestamps))

motp = TimeRangeMetric(title="MOTP",
value=motp_score,
time_range=time_range,
generator=self)
mota = TimeRangeMetric(title="MOTA",
value=mota_score,
time_range=time_range,
generator=self)
return [motp, mota]

def _compute_mota_and_motp(self, manager: MultiManager) -> Tuple[float, float]:

matches_at_time_lookup = self._create_matches_at_time_lookup(manager)

check_matches_for_metric_calculation(matches_at_time_lookup)

truths_set = manager.states_sets[self.truths_key]
tracks_set = manager.states_sets[self.tracks_key]

truth_states_by_time_and_id: StatesFromTimeIdLookup = \
_create_state_from_time_and_id_lookup(truths_set)
track_states_by_time_and_id: StatesFromTimeIdLookup = \
_create_state_from_time_and_id_lookup(tracks_set)

# used for the MOTP (avg-distance over truth-track associations)
error_sum = 0.0
num_associated_truth_timestamps = 0

# used for the MOTA (1 - number-FPs, ID-changes etc. / number-GT-states)
num_misses, num_false_positives, num_miss_matches = 0, 0, 0

unique_timestamps = sorted(manager.list_timestamps(generator=self))

for i, timestamp in enumerate(unique_timestamps):

matches_current = matches_at_time_lookup[timestamp]

matched_truth_ids_curr = {match[0] for match in matches_current}
matched_tracks_at_timestamp = {match[1] for match in matches_current}

# update the variables for MOTP calculation
error_sum_in_timestep = self._compute_sum_of_distances_at_timestep(
truth_states_by_time_and_id, track_states_by_time_and_id, timestamp,
matches_current)
error_sum += error_sum_in_timestep
num_associated_truth_timestamps += len(matches_current)

truths_ids_at_timestamp = truth_states_by_time_and_id[timestamp].keys()
tracks_ids_at_timestamp = track_states_by_time_and_id[timestamp].keys()

unmatched_truth_ids = truths_ids_at_timestamp - matched_truth_ids_curr
unmatched_track_ids = tracks_ids_at_timestamp - matched_tracks_at_timestamp

# update counter variables used for MOTA
num_misses += len(unmatched_truth_ids)
num_false_positives += len(unmatched_track_ids)

if i > 0:
# for number of mis-matches (i.e. track ID changes for a single truth track)
matches_prev = matches_at_time_lookup[unique_timestamps[i - 1]]
num_miss_matches_current = self._compute_number_of_miss_matches_from_match_sets(
matches_prev, matches_current)
num_miss_matches += num_miss_matches_current

motp = (error_sum / num_associated_truth_timestamps) \
if num_associated_truth_timestamps > 0 else float("inf")

number_of_gt_states = self._compute_total_number_of_gt_states(manager)
mota = 1 - (num_misses + num_false_positives + num_miss_matches) / number_of_gt_states

return motp, mota

def _compute_sum_of_distances_at_timestep(self,
truth_states_by_time_id: StatesFromTimeIdLookup,
track_states_by_time_id: StatesFromTimeIdLookup,
timestamp: datetime.datetime,
matches_current: MatchSetAtTimestamp) -> float:
error_sum_in_timestep = 0.0
for match in matches_current:
truth_id = match[0]
track_id = match[1]

truth_state_at_t = truth_states_by_time_id[timestamp][truth_id]
track_state_at_t = track_states_by_time_id[timestamp][track_id]

error = self.distance_measure(truth_state_at_t, track_state_at_t)
error_sum_in_timestep += error
return error_sum_in_timestep

def _compute_total_number_of_gt_states(self, manager: MultiManager) -> int:
truth_state_set: Set[Track] = manager.states_sets[self.truths_key]
total_number_of_gt_states = sum(len(truth_track) for truth_track in truth_state_set)
return total_number_of_gt_states

def _create_matches_at_time_lookup(self, manager: MultiManager) \
-> Dict[datetime.datetime, MatchSetAtTimestamp]:
timestamps = manager.list_timestamps(generator=self)

matches_by_timestamp = defaultdict(set)

for i, timestamp in enumerate(timestamps):

associations = manager.association_set.associations_at_timestamp(timestamp)

for association in associations:
truth, track = self.truth_track_from_association(association)
match_truth_track = (truth.id, track.id)
matches_by_timestamp[timestamp].add(match_truth_track)
return matches_by_timestamp

def _compute_number_of_miss_matches_from_match_sets(self,
matches_prev: MatchSetAtTimestamp,
matches_current: MatchSetAtTimestamp)\
-> int:
num_miss_matches_current = 0

matched_truth_ids_prev = {match[0] for match in matches_prev}
matched_truth_ids_curr = {match[0] for match in matches_current}
truths_ids_at_both_timestamps = matched_truth_ids_prev & matched_truth_ids_curr

for truth_id in truths_ids_at_both_timestamps:
matched_track_id_prev = next(
match[1] for match in matches_prev if match[0] == truth_id)
matched_track_id_curr = next(
match[1] for match in matches_current if match[0] == truth_id)

if matched_track_id_prev != matched_track_id_curr:
num_miss_matches_current += 1
return num_miss_matches_current

@staticmethod
def truth_track_from_association(association) -> Tuple[Track, Track]:
"""Find truth and track from an association.
Parameters
----------
association: Association
Association that contains truth and track as its objects
Returns
-------
GroundTruthPath, Track
True object and track that are the objects of the `association`
"""
truth, track = association.objects
# Sets aren't ordered, so need to ensure correct path is truth/track
if isinstance(truth, Track):
truth, track = track, truth
return truth, track


def _create_state_from_time_and_id_lookup(tracks_set: Set[Union[Track, GroundTruthPath]]) \
-> StatesFromTimeIdLookup:
track_states_by_time_id: StatesFromTimeIdLookup = defaultdict(dict)
for track in tracks_set:
for state in track.last_timestamp_generator():
track_states_by_time_id[state.timestamp][track.id] = state
return track_states_by_time_id


class AssociationSetNotValid(Exception):
pass


def check_matches_for_metric_calculation(matches_by_timestamp:
Dict[datetime.datetime, MatchSetAtTimestamp]):
"""Checks the matches prior to computing CLEAR MOT metrics. If this function returns
without raising an exception, it is checked that a single track is associated with one truth
(one-2-one relationship) at a given timestep and vice versa.
Parameters
----------
matches_by_timestamp: Dict[datetime.datetime, MatchSetAtTimestamp]
Dictionary which returns a set of (truth, track) matches for a given timestamp.
Raises
------
AssociationSetNotValid
"""

for t, matches in matches_by_timestamp.items():
truth_ids = [m[0] for m in matches]
if len(truth_ids) > len(set(truth_ids)):
raise AssociationSetNotValid(f"Multiple tracks are assigned with "
f"the same truth track at time {t}!"
" Resolve this ambiguity in order to continue.")

track_ids = [m[1] for m in matches]
if len(track_ids) > len(set(track_ids)):
raise AssociationSetNotValid(f"A single track is assigned with "
f"multiple truth tracks at time {t}!"
" Resolve this ambiguity in order to continue.")
25 changes: 18 additions & 7 deletions stonesoup/metricgenerator/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,37 @@
import pytest

from ...metricgenerator.manager import MultiManager
from ...types.association import TimeRangeAssociation, AssociationSet
from ...models.measurement.linear import LinearGaussian
from ...models.transition.linear import (
CombinedLinearGaussianTransitionModel,
ConstantVelocity,
)
from ...types.array import CovarianceMatrix, StateVector
from ...types.association import AssociationSet, TimeRangeAssociation
from ...types.detection import Detection
from ...types.groundtruth import GroundTruthPath, GroundTruthState
from ...types.hypothesis import SingleDistanceHypothesis
from ...types.prediction import GaussianStatePrediction
from ...types.time import TimeRange
from ...types.track import Track
from ...types.update import GaussianStateUpdate
from ...types.array import CovarianceMatrix, StateVector
from ...models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity
from ...models.measurement.linear import LinearGaussian


@pytest.fixture
def time_period() -> timedelta:
return timedelta(seconds=1)


@pytest.fixture()
def trial_timestamps():
now = datetime.now()
return [now + timedelta(seconds=i) for i in range(4)]
def trial_timestamps(time_period: timedelta):
now = datetime(2024, 1, 1, 0, 0, 0)
return [now + i*time_period for i in range(4)]


@pytest.fixture()
def trial_truths(trial_timestamps):
return [
# object moving from (x=0, y=0) to (x=3, y=3) with (vx=1, vy=1)
GroundTruthPath([
GroundTruthState(np.array([[0, 1, 0, 1]]), timestamp=trial_timestamps[0],
metadata={"colour": "red"}),
Expand All @@ -36,6 +45,7 @@ def trial_truths(trial_timestamps):
GroundTruthState(np.array([[3, 1, 3, 1]]), timestamp=trial_timestamps[3],
metadata={"colour": "red"})
]),
# object moving from (x=-2, y=-2) to (x=2, y=2) with (vx=1, vy=1)
GroundTruthPath([
GroundTruthState(np.array([[-2, 1, -2, 1]]), timestamp=trial_timestamps[0],
metadata={"colour": "green"}),
Expand All @@ -46,6 +56,7 @@ def trial_truths(trial_timestamps):
GroundTruthState(np.array([[2, 1, 2, 1]]), timestamp=trial_timestamps[3],
metadata={"colour": "green"})
]),
# object moving from (x=--1, y=1) to (x=3, y=3) with (vx=1, vy=0)
GroundTruthPath([
GroundTruthState(np.array([[-1, 1, 1, 0]]), timestamp=trial_timestamps[0],
metadata={"colour": "blue"}),
Expand Down
Loading

0 comments on commit 3cd4a8c

Please sign in to comment.