Skip to content

Commit

Permalink
18.5% performance improvement in MPPI
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveMacenski committed Mar 8, 2024
1 parent 2636371 commit 7a0e651
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace mppi
class CriticManager
{
public:
typedef std::vector<std::unique_ptr<critics::CriticFunction>> Critics;
/**
* @brief Constructor for mppi::CriticManager
*/
Expand Down Expand Up @@ -94,7 +95,7 @@ class CriticManager
ParametersHandler * parameters_handler_;
std::vector<std::string> critic_names_;
std::unique_ptr<pluginlib::ClassLoader<critics::CriticFunction>> loader_;
std::vector<std::unique_ptr<critics::CriticFunction>> critics_;
Critics critics_;

rclcpp::Logger logger_{rclcpp::get_logger("MPPIController")};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ class CostCritic : public CriticFunction
*/
bool inCollision(float cost, float x, float y, float theta);

/**
* @brief cost at a robot pose
* @param x X of pose
* @param y Y of pose
* @return Collision information at pose
*/
float costAtPose(float x, float y);

/**
* @brief Find the min cost of the inflation decay function for which the robot MAY be
* in collision in any orientation
Expand All @@ -81,6 +73,7 @@ class CostCritic : public CriticFunction
float possible_collision_cost_;

bool consider_footprint_{true};
bool is_tracking_unknown_{true};
float circumscribed_radius_{0};
float circumscribed_cost_{0};
float collision_cost_{0};
Expand Down
30 changes: 16 additions & 14 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
#include "builtin_interfaces/msg/time.hpp"
#include "nav2_mppi_controller/critic_data.hpp"

#define M_PIF 3.141592653589793238462643383279502884e+00F
#define M_PIF_2 1.5707963267948966e+00F

namespace mppi::utils
{
using xt::evaluation_strategy::immediate;
Expand Down Expand Up @@ -259,16 +262,16 @@ inline bool withinPositionGoalTolerance(

/**
* @brief normalize
* Normalizes the angle to be -M_PI circle to +M_PI circle
* Normalizes the angle to be -M_PIF circle to +M_PIF circle
* It takes and returns radians.
* @param angles Angles to normalize
* @return normalized angles
*/
template<typename T>
auto normalize_angles(const T & angles)
{
auto && theta = xt::eval(xt::fmod(angles + M_PI, 2.0 * M_PI));
return xt::eval(xt::where(theta <= 0.0, theta + M_PI, theta - M_PI));
auto && theta = xt::eval(xt::fmod(angles + M_PIF, 2.0f * M_PIF));
return xt::eval(xt::where(theta <= 0.0f, theta + M_PIF, theta - M_PIF));
}

/**
Expand Down Expand Up @@ -310,13 +313,12 @@ inline size_t findPathFurthestReachedPoint(const CriticData & data)

size_t max_id_by_trajectories = 0, min_id_by_path = 0;
float min_distance_by_path = std::numeric_limits<float>::max();
float cur_dist = 0.0f;

for (size_t i = 0; i < dists.shape(0); i++) {
min_id_by_path = 0;
min_distance_by_path = std::numeric_limits<float>::max();
for (size_t j = 0; j < dists.shape(1); j++) {
cur_dist = dists(i, j);
for (size_t j = max_id_by_trajectories; j < dists.shape(1); j++) {
const float & cur_dist = dists(i, j);
if (cur_dist < min_distance_by_path) {
min_distance_by_path = cur_dist;
min_id_by_path = j;
Expand Down Expand Up @@ -436,7 +438,7 @@ inline float posePointAngle(
if (!forward_preference) {
return std::min(
fabs(angles::shortest_angular_distance(yaw, pose_yaw)),
fabs(angles::shortest_angular_distance(yaw, angles::normalize_angle(pose_yaw + M_PI))));
fabs(angles::shortest_angular_distance(yaw, angles::normalize_angle(pose_yaw + M_PIF))));
}

return fabs(angles::shortest_angular_distance(yaw, pose_yaw));
Expand All @@ -454,14 +456,14 @@ inline float posePointAngle(
const geometry_msgs::msg::Pose & pose,
double point_x, double point_y, double point_yaw)
{
float pose_x = pose.position.x;
float pose_y = pose.position.y;
float pose_yaw = tf2::getYaw(pose.orientation);
float pose_x = static_cast<float>(pose.position.x);
float pose_y = static_cast<float>(pose.position.y);
float pose_yaw = static_cast<float>(tf2::getYaw(pose.orientation));

float yaw = atan2f(point_y - pose_y, point_x - pose_x);
float yaw = atan2f(static_cast<float>(point_y) - pose_y, static_cast<float>(point_x) - pose_x);

if (fabs(angles::shortest_angular_distance(yaw, point_yaw)) > M_PI_2) {
yaw = angles::normalize_angle(yaw + M_PI);
if (fabs(angles::shortest_angular_distance(yaw, static_cast<float>(point_yaw))) > M_PIF_2) {
yaw = angles::normalize_angle(yaw + M_PIF);
}

return fabs(angles::shortest_angular_distance(yaw, pose_yaw));
Expand Down Expand Up @@ -664,7 +666,7 @@ inline unsigned int findFirstPathInversion(nav_msgs::msg::Path & path)

// Checking for the existance of cusp, in the path, using the dot product.
float dot_product = (oa_x * ab_x) + (oa_y * ab_y);
if (dot_product < 0.0) {
if (dot_product < 0.0f) {
return idx + 1;
}
}
Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critic_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ std::string CriticManager::getFullName(const std::string & name)
void CriticManager::evalTrajectoriesScores(
CriticData & data) const
{
for (size_t q = 0; q < critics_.size(); q++) {
for (const auto & critic : critics_) {
if (data.fail_flag) {
break;
}
critics_[q]->score(data);
critic->score(data);
}
}

Expand Down
50 changes: 32 additions & 18 deletions nav2_mppi_controller/src/critics/constraint_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ void ConstraintCritic::initialize()
auto getParentParam = parameters_handler_->getParamGetter(parent_name_);

getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 4.0);
getParam(weight_, "cost_weight", 4.0f);
RCLCPP_INFO(
logger_, "ConstraintCritic instantiated with %d power and %f weight.",
power_, weight_);

float vx_max, vy_max, vx_min;
getParentParam(vx_max, "vx_max", 0.5);
getParentParam(vy_max, "vy_max", 0.0);
getParentParam(vx_min, "vx_min", -0.35);
getParentParam(vx_max, "vx_max", 0.5f);
getParentParam(vy_max, "vy_max", 0.0f);
getParentParam(vx_min, "vx_min", -0.35f);

const float min_sgn = vx_min > 0.0 ? 1.0 : -1.0;
const float min_sgn = vx_min > 0.0f ? 1.0f : -1.0f;
max_vel_ = sqrtf(vx_max * vx_max + vy_max * vy_max);
min_vel_ = min_sgn * sqrtf(vx_min * vx_min + vy_max * vy_max);
}
Expand All @@ -46,32 +46,46 @@ void ConstraintCritic::score(CriticData & data)
return;
}

auto sgn = xt::where(data.state.vx > 0.0, 1.0, -1.0);
auto vel_total = sgn * xt::sqrt(data.state.vx * data.state.vx + data.state.vy * data.state.vy);
auto out_of_max_bounds_motion = xt::maximum(vel_total - max_vel_, 0);
auto out_of_min_bounds_motion = xt::maximum(min_vel_ - vel_total, 0);
// Differential motion model
auto diff = dynamic_cast<DiffDriveMotionModel *>(data.motion_model.get());
if (diff != nullptr) {
data.costs += xt::pow(
xt::sum(
(std::move(xt::maximum(data.state.vx - max_vel_, 0.0f)) +
std::move(xt::maximum(min_vel_ - data.state.vx, 0.0f))) *
data.model_dt, {1}, immediate) * weight_, power_);
return;
}

// Omnidirectional motion model
auto omni = dynamic_cast<OmniMotionModel *>(data.motion_model.get());
if (omni != nullptr) {
auto sgn = xt::eval(xt::where(data.state.vx > 0.0f, 1.0f, -1.0f));
auto vel_total = sgn * xt::hypot(data.state.vx, data.state.vy);
data.costs += xt::pow(
xt::sum(
(std::move(xt::maximum(vel_total - max_vel_, 0.0f)) +
std::move(xt::maximum(min_vel_ - vel_total, 0.0f))) *
data.model_dt, {1}, immediate) * weight_, power_);
return;
}

// Ackermann motion model
auto acker = dynamic_cast<AckermannMotionModel *>(data.motion_model.get());
if (acker != nullptr) {
auto & vx = data.state.vx;
auto & wz = data.state.wz;
auto out_of_turning_rad_motion = xt::maximum(
acker->getMinTurningRadius() - (xt::fabs(vx) / xt::fabs(wz)), 0.0);
acker->getMinTurningRadius() - (xt::fabs(vx) / xt::fabs(wz)), 0.0f);

data.costs += xt::pow(
xt::sum(
(std::move(out_of_max_bounds_motion) +
std::move(out_of_min_bounds_motion) +
(std::move(xt::maximum(data.state.vx - max_vel_, 0.0f)) +
std::move(xt::maximum(min_vel_ - data.state.vx, 0.0f)) +
std::move(out_of_turning_rad_motion)) *
data.model_dt, {1}, immediate) * weight_, power_);
return;
}

data.costs += xt::pow(
xt::sum(
(std::move(out_of_max_bounds_motion) +
std::move(out_of_min_bounds_motion)) *
data.model_dt, {1}, immediate) * weight_, power_);
}

} // namespace mppi::critics
Expand Down
59 changes: 28 additions & 31 deletions nav2_mppi_controller/src/critics/cost_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ void CostCritic::initialize()
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(consider_footprint_, "consider_footprint", false);
getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 3.81);
getParam(critical_cost_, "critical_cost", 300.0);
getParam(collision_cost_, "collision_cost", 1000000.0);
getParam(near_goal_distance_, "near_goal_distance", 0.5);
getParam(weight_, "cost_weight", 3.81f);
getParam(critical_cost_, "critical_cost", 300.0f);
getParam(collision_cost_, "collision_cost", 1000000.0f);
getParam(near_goal_distance_, "near_goal_distance", 0.5f);
getParam(inflation_layer_name_, "inflation_layer_name", std::string(""));

// Normalized by cost value to put in same regime as other weights
Expand Down Expand Up @@ -94,6 +94,8 @@ void CostCritic::score(CriticData & data)
return;
}

is_tracking_unknown_ = costmap_ros_->getLayeredCostmap()->isTrackingUnknown();

if (consider_footprint_) {
// footprint may have changed since initialization if user has dynamic footprints
possible_collision_cost_ = findCircumscribedCost(costmap_ros_);
Expand All @@ -106,32 +108,42 @@ void CostCritic::score(CriticData & data)
}

auto && repulsive_cost = xt::xtensor<float, 1>::from_shape({data.costs.shape(0)});
repulsive_cost.fill(0.0);
repulsive_cost.fill(0.0f);

unsigned int x_i = 0u, y_i = 0u;
float pose_cost;
auto costmap = collision_checker_.getCostmap();

const size_t traj_len = data.trajectories.x.shape(1);
bool all_trajectories_collide = true;
for (size_t i = 0; i < data.trajectories.x.shape(0); ++i) {
bool trajectory_collide = false;
const auto & traj = data.trajectories;
float pose_cost;
const auto & traj_x = xt::view(data.trajectories.x, i, xt::all());
const auto & traj_y = xt::view(data.trajectories.y, i, xt::all());
const auto & traj_yaw = xt::view(data.trajectories.yaws, i, xt::all());
pose_cost = 0.0f;

for (size_t j = 0; j < traj_len; j++) {
// The costAtPose doesn't use orientation
// The getCost doesn't use orientation
// The footprintCostAtPose will always return "INSCRIBED" if footprint is over it
// So the center point has more information than the footprint
pose_cost = costAtPose(traj.x(i, j), traj.y(i, j));
if (!costmap->worldToMap(traj_x(j), traj_y(j), x_i, y_i)) {
pose_cost = static_cast<float>(nav2_costmap_2d::NO_INFORMATION);
} else {
pose_cost = static_cast<float>(costmap->getCost(x_i, y_i));
}

if (pose_cost < 1.0f) {continue;} // In free space

if (inCollision(pose_cost, traj.x(i, j), traj.y(i, j), traj.yaws(i, j))) {
if (inCollision(pose_cost, traj_x(j), traj_y(j), traj_yaw(j))) {
trajectory_collide = true;
break;
}

// Let near-collision trajectory points be punished severely
// Note that we collision check based on the footprint actual,
// but score based on the center-point cost regardless
using namespace nav2_costmap_2d; // NOLINT
if (pose_cost >= INSCRIBED_INFLATED_OBSTACLE) {
if (pose_cost >= nav2_costmap_2d::INSCRIBED_INFLATED_OBSTACLE) {
repulsive_cost[i] += critical_cost_;
} else if (!near_goal) { // Generally prefer trajectories further from obstacles
repulsive_cost[i] += pose_cost;
Expand All @@ -156,9 +168,6 @@ void CostCritic::score(CriticData & data)
*/
bool CostCritic::inCollision(float cost, float x, float y, float theta)
{
bool is_tracking_unknown =
costmap_ros_->getLayeredCostmap()->isTrackingUnknown();

// If consider_footprint_ check footprint scort for collision
if (consider_footprint_ &&
(cost >= possible_collision_cost_ || possible_collision_cost_ < 1.0f))
Expand All @@ -168,29 +177,17 @@ bool CostCritic::inCollision(float cost, float x, float y, float theta)
}

switch (static_cast<unsigned char>(cost)) {
using namespace nav2_costmap_2d; // NOLINT
case (LETHAL_OBSTACLE):
case (nav2_costmap_2d::LETHAL_OBSTACLE):
return true;
case (INSCRIBED_INFLATED_OBSTACLE):
case (nav2_costmap_2d::INSCRIBED_INFLATED_OBSTACLE):
return consider_footprint_ ? false : true;
case (NO_INFORMATION):
return is_tracking_unknown ? false : true;
case (nav2_costmap_2d::NO_INFORMATION):
return is_tracking_unknown_ ? false : true;
}

return false;
}

float CostCritic::costAtPose(float x, float y)
{
using namespace nav2_costmap_2d; // NOLINT
unsigned int x_i, y_i;
if (!collision_checker_.worldToMap(x, y, x_i, y_i)) {
return nav2_costmap_2d::NO_INFORMATION;
}

return collision_checker_.pointCost(x_i, y_i);
}

} // namespace mppi::critics

#include <pluginlib/class_list_macros.hpp>
Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critics/goal_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ void GoalAngleCritic::initialize()
auto getParam = parameters_handler_->getParamGetter(name_);

getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 3.0);
getParam(weight_, "cost_weight", 3.0f);

getParam(threshold_to_consider_, "threshold_to_consider", 0.5);
getParam(threshold_to_consider_, "threshold_to_consider", 0.5f);

RCLCPP_INFO(
logger_,
Expand Down
12 changes: 5 additions & 7 deletions nav2_mppi_controller/src/critics/goal_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ void GoalCritic::initialize()
auto getParam = parameters_handler_->getParamGetter(name_);

getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 5.0);
getParam(threshold_to_consider_, "threshold_to_consider", 1.4);
getParam(weight_, "cost_weight", 5.0f);
getParam(threshold_to_consider_, "threshold_to_consider", 1.4f);

RCLCPP_INFO(
logger_, "GoalCritic instantiated with %d power and %f weight.",
Expand All @@ -49,11 +49,9 @@ void GoalCritic::score(CriticData & data)
const auto traj_x = xt::view(data.trajectories.x, xt::all(), xt::all());
const auto traj_y = xt::view(data.trajectories.y, xt::all(), xt::all());

auto dists = xt::sqrt(
xt::pow(traj_x - goal_x, 2) +
xt::pow(traj_y - goal_y, 2));

data.costs += xt::pow(xt::mean(dists, {1}, immediate) * weight_, power_);
data.costs += xt::pow(
xt::mean(xt::hypot(traj_x - goal_x, traj_y - goal_y),
{1}, immediate) * weight_, power_);
}

} // namespace mppi::critics
Expand Down
10 changes: 5 additions & 5 deletions nav2_mppi_controller/src/critics/obstacles_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ void ObstaclesCritic::initialize()
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(consider_footprint_, "consider_footprint", false);
getParam(power_, "cost_power", 1);
getParam(repulsion_weight_, "repulsion_weight", 1.5);
getParam(critical_weight_, "critical_weight", 20.0);
getParam(collision_cost_, "collision_cost", 100000.0);
getParam(collision_margin_distance_, "collision_margin_distance", 0.10);
getParam(near_goal_distance_, "near_goal_distance", 0.5);
getParam(repulsion_weight_, "repulsion_weight", 1.5f);
getParam(critical_weight_, "critical_weight", 20.0f);
getParam(collision_cost_, "collision_cost", 100000.0f);
getParam(collision_margin_distance_, "collision_margin_distance", 0.10f);
getParam(near_goal_distance_, "near_goal_distance", 0.5f);
getParam(inflation_layer_name_, "inflation_layer_name", std::string(""));

collision_checker_.setCostmap(costmap_);
Expand Down
Loading

0 comments on commit 7a0e651

Please sign in to comment.