diff --git a/perception/multi_object_tracker/include/multi_object_tracker/tracker/model/tracker_base.hpp b/perception/multi_object_tracker/include/multi_object_tracker/tracker/model/tracker_base.hpp index 95c23d8e42019..d2850793fa4bb 100644 --- a/perception/multi_object_tracker/include/multi_object_tracker/tracker/model/tracker_base.hpp +++ b/perception/multi_object_tracker/include/multi_object_tracker/tracker/model/tracker_base.hpp @@ -42,6 +42,8 @@ class Tracker { classification_ = classification; } + void updateClassification( + const std::vector & classification); private: unique_identifier_msgs::msg::UUID uuid_; diff --git a/perception/multi_object_tracker/src/tracker/model/multiple_vehicle_tracker.cpp b/perception/multi_object_tracker/src/tracker/model/multiple_vehicle_tracker.cpp index 51adca7e69b56..4f0fb4d7871f2 100644 --- a/perception/multi_object_tracker/src/tracker/model/multiple_vehicle_tracker.cpp +++ b/perception/multi_object_tracker/src/tracker/model/multiple_vehicle_tracker.cpp @@ -43,7 +43,7 @@ bool MultipleVehicleTracker::measure( big_vehicle_tracker_.measure(object, time, self_transform); normal_vehicle_tracker_.measure(object, time, self_transform); if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN) - setClassification(object.classification); + updateClassification(object.classification); return true; } diff --git a/perception/multi_object_tracker/src/tracker/model/pedestrian_and_bicycle_tracker.cpp b/perception/multi_object_tracker/src/tracker/model/pedestrian_and_bicycle_tracker.cpp index eed9d05359b77..d61a9a02ccd80 100644 --- a/perception/multi_object_tracker/src/tracker/model/pedestrian_and_bicycle_tracker.cpp +++ b/perception/multi_object_tracker/src/tracker/model/pedestrian_and_bicycle_tracker.cpp @@ -43,7 +43,7 @@ bool PedestrianAndBicycleTracker::measure( pedestrian_tracker_.measure(object, time, self_transform); bicycle_tracker_.measure(object, time, self_transform); if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN) - setClassification(object.classification); + updateClassification(object.classification); return true; } diff --git a/perception/multi_object_tracker/src/tracker/model/tracker_base.cpp b/perception/multi_object_tracker/src/tracker/model/tracker_base.cpp index ba684e4777947..a3320ff54afcb 100644 --- a/perception/multi_object_tracker/src/tracker/model/tracker_base.cpp +++ b/perception/multi_object_tracker/src/tracker/model/tracker_base.cpp @@ -54,6 +54,67 @@ bool Tracker::updateWithoutMeasurement() return true; } +void Tracker::updateClassification( + const std::vector & classification) +{ + // classification algorithm: + // 0. Normalize the input classification + // 1-1. Update the matched classification probability with a gain (ratio of 0.05) + // 1-2. If the label is not found, add it to the classification list + // 2. Remove the class with probability < remove_threshold (0.001) + // 3. Normalize tracking classification + + // Parameters + // if the remove_threshold is too high (compare to the gain), the classification will be removed + // immediately + const double gain = 0.05; + constexpr double remove_threshold = 0.001; + + // Normalization function + auto normalizeProbabilities = + [](std::vector & classification) { + double sum = 0.0; + for (const auto & class_ : classification) { + sum += class_.probability; + } + for (auto & class_ : classification) { + class_.probability /= sum; + } + }; + + // Normalize the input + auto classification_input = classification; + normalizeProbabilities(classification_input); + + // Update the matched classification probability with a gain + for (const auto & new_class : classification_input) { + bool found = false; + for (auto & old_class : classification_) { + if (new_class.label == old_class.label) { + old_class.probability += new_class.probability * gain; + found = true; + break; + } + } + // If the label is not found, add it to the classification list + if (!found) { + auto adding_class = new_class; + adding_class.probability *= gain; + classification_.push_back(adding_class); + } + } + + // If the probability is less than the threshold, remove the class + classification_.erase( + std::remove_if( + classification_.begin(), classification_.end(), + [remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }), + classification_.end()); + + // Normalize tracking classification + normalizeProbabilities(classification_); +} + geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance( const rclcpp::Time & time) const {