Skip to content

Commit

Permalink
Fixed all the unit tests and critic tests, all unit tests passing loc…
Browse files Browse the repository at this point in the history
…ally

Signed-off-by: Ayush1285 <[email protected]>
  • Loading branch information
Ayush1285 committed Oct 2, 2024
1 parent 2c53db4 commit 847f837
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 106 deletions.
29 changes: 15 additions & 14 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,12 @@ inline bool withinPositionGoalTolerance(
template<typename T>
auto normalize_angles(const T & angles)
{
// Eigen::ArrayXXf theta = (angles + M_PIF).unaryExpr([](const float x){
// float remainder = std::fmod(x, 2.0f * M_PIF);
// return remainder < 0.0f ? remainder + M_PIF : remainder - M_PIF;
// });
// return ((theta < 0.0f).select(theta + M_PIF, theta - M_PIF)).eval();
auto theta = angles - (M_PIF * ((angles + M_PIF_2) * (1.0f / M_PIF)).floor());
return theta;
return (angles + M_PIF).unaryExpr([&](const float x){
float remainder = std::fmod(x, 2.0f * M_PIF);
return remainder < 0.0f ? remainder + M_PIF : remainder - M_PIF;
});
// auto theta = angles - (M_PIF * ((angles + M_PIF_2) * (1.0f / M_PIF)).floor());
// return theta;
}

/**
Expand Down Expand Up @@ -303,11 +302,12 @@ auto shortest_angular_distance(
*/
inline size_t findPathFurthestReachedPoint(const CriticData & data)
{
const auto traj_x = data.trajectories.x.col(-1);
const auto traj_y = data.trajectories.y.col(-1);
int traj_cols = data.trajectories.x.cols();
const auto traj_x = data.trajectories.x.col(traj_cols - 1);
const auto traj_y = data.trajectories.y.col(traj_cols - 1);

const auto dx = (data.path.x.transpose()).replicate(traj_x.rows(), 1) - traj_x;
const auto dy = (data.path.y.transpose()).replicate(traj_y.rows(), 1) - traj_y;
const auto dx = (data.path.x.transpose()).replicate(traj_x.rows(), 1).colwise() - traj_x;
const auto dy = (data.path.y.transpose()).replicate(traj_y.rows(), 1).colwise() - traj_y;

const auto dists = dx * dx + dy * dy;

Expand Down Expand Up @@ -453,8 +453,9 @@ inline void savitskyGolayFilter(
const models::OptimizerSettings & settings)
{
// Savitzky-Golay Quadratic, 9-point Coefficients
Eigen::Array<float, 9, 1> filter = {-21.0, 14.0, 39.0, 54.0, 59.0, 54.0, 39.0, 14.0, -21.0};
filter /= 231.0;
Eigen::Array<float, 9, 1> filter = {-21.0f, 14.0f, 39.0f, 54.0f, 59.0f, 54.0f, 39.0f, 14.0f,
-21.0f};
filter /= 231.0f;

const unsigned int num_sequences = control_sequence.vx.size() - 1;

Expand All @@ -464,7 +465,7 @@ inline void savitskyGolayFilter(
}

auto applyFilter = [&](const Eigen::Array<float, 9, 1> & data) -> float {
return (data * filter).sum();
return (data * filter).eval().sum();
};

auto applyFilterOverAxis =
Expand Down
19 changes: 10 additions & 9 deletions nav2_mppi_controller/src/critics/constraint_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ void ConstraintCritic::score(CriticData & data)
if (diff != nullptr) {
if (power_ > 1u) {
data.costs += (((((data.state.vx - max_vel_).max(0.0f) + (min_vel_ - data.state.vx).
max(0.0f)) * data.model_dt).rowwise().sum().eval()) * weight_).pow(power_);
max(0.0f)) * data.model_dt).rowwise().sum().eval()) * weight_).pow(power_).eval();
} else {
data.costs += ((((data.state.vx - max_vel_).max(0.0f) + (min_vel_ - data.state.vx).
max(0.0f)) * data.model_dt).rowwise().sum().eval()) * weight_;
data.costs += (((((data.state.vx - max_vel_).max(0.0f) + (min_vel_ - data.state.vx).
max(0.0f)) * data.model_dt).rowwise().sum().eval()) * weight_).eval();
}
return;
}
Expand All @@ -69,10 +69,10 @@ void ConstraintCritic::score(CriticData & data)
auto vel_total = (data.state.vx.square() + data.state.vy.square()).sqrt() * sgn;
if (power_ > 1u) {
data.costs += ((std::move((vel_total - max_vel_).max(0.0f) + (min_vel_ - vel_total).
max(0.0f)) * data.model_dt).rowwise().sum().eval() * weight_).pow(power_);
max(0.0f)) * data.model_dt).rowwise().sum().eval() * weight_).pow(power_).eval();
} else {
data.costs += (std::move((vel_total - max_vel_).max(0.0f) + (min_vel_ - vel_total).
max(0.0f)) * data.model_dt).rowwise().sum().eval() * weight_;
data.costs += ((std::move((vel_total - max_vel_).max(0.0f) + (min_vel_ - vel_total).
max(0.0f)) * data.model_dt).rowwise().sum().eval() * weight_).eval();
}
return;
}
Expand All @@ -86,10 +86,11 @@ void ConstraintCritic::score(CriticData & data)
auto out_of_turning_rad_motion = (min_turning_rad - (vx.abs() / wz.abs())).max(0.0f);
if (power_ > 1u) {
data.costs += ((std::move((vx - max_vel_).max(0.0f) + (min_vel_ - vx).max(0.0f) +
out_of_turning_rad_motion) * data.model_dt).rowwise().sum().eval() * weight_).pow(power_);
out_of_turning_rad_motion) * data.model_dt).rowwise().sum().eval() *
weight_).pow(power_).eval();
} else {
data.costs += (std::move((vx - max_vel_).max(0.0f) + (min_vel_ - vx).max(0.0f) +
out_of_turning_rad_motion) * data.model_dt).rowwise().sum().eval() * weight_;
data.costs += ((std::move((vx - max_vel_).max(0.0f) + (min_vel_ - vx).max(0.0f) +
out_of_turning_rad_motion) * data.model_dt).rowwise().sum().eval() * weight_).eval();
}
return;
}
Expand Down
6 changes: 3 additions & 3 deletions nav2_mppi_controller/src/critics/goal_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ void GoalAngleCritic::score(CriticData & data)

if(power_ > 1u) {
data.costs += (((utils::shortest_angular_distance(data.trajectories.yaws, goal_yaw).abs()).
rowwise().mean()) * weight_).pow(power_);
rowwise().mean()) * weight_).pow(power_).eval();
} else {
data.costs += ((utils::shortest_angular_distance(data.trajectories.yaws, goal_yaw).abs()).
rowwise().mean()) * weight_;
data.costs += (((utils::shortest_angular_distance(data.trajectories.yaws, goal_yaw).abs()).
rowwise().mean()) * weight_).eval();
}
}

Expand Down
7 changes: 4 additions & 3 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ void PathAlignCritic::score(CriticData & data)
}
}

const size_t batch_size = data.trajectories.x.cols();
const size_t batch_size = data.trajectories.x.rows();
Eigen::ArrayXf cost(data.costs.rows());
cost.setZero();

// Find integrated distance in the path
std::vector<float> path_integrated_distances(path_segments_count, 0.0f);
Expand Down Expand Up @@ -151,9 +152,9 @@ void PathAlignCritic::score(CriticData & data)
}

if (power_ > 1u) {
data.costs += (std::move(cost) * weight_).pow(power_);
data.costs += (cost * weight_).pow(power_).eval();
} else {
data.costs += std::move(cost) * weight_;
data.costs += (cost * weight_).eval();
}
}

Expand Down
15 changes: 9 additions & 6 deletions nav2_mppi_controller/src/critics/path_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,18 @@ void PathAngleCritic::score(CriticData & data)
throw nav2_core::ControllerException("Invalid path angle mode!");
}

int && rightmost_idx = data.trajectories.y.cols() - 1;
int rightmost_idx = data.trajectories.y.cols() - 1;
auto diff_y = goal_y - data.trajectories.y.col(rightmost_idx);
auto diff_x = goal_x - data.trajectories.x.col(rightmost_idx);
auto yaws_between_points = diff_y.binaryExpr(
diff_x, [&](const float & y, const float & x){return atan2f(y, x);});
diff_x, [&](const float & y, const float & x){return atan2f(y, x);}).eval();

switch (mode_) {
case PathAngleMode::FORWARD_PREFERENCE:
{
auto rightmost_yaw = data.trajectories.yaws.col(rightmost_idx);
auto yaws = utils::shortest_angular_distance(
data.trajectories.yaws.col(rightmost_idx), yaws_between_points).abs();
rightmost_yaw, yaws_between_points).abs();
if (power_ > 1u) {
data.costs += (yaws * weight_).pow(power_);
} else {
Expand All @@ -117,11 +118,12 @@ void PathAngleCritic::score(CriticData & data)
}
case PathAngleMode::NO_DIRECTIONAL_PREFERENCE:
{
auto rightmost_yaw = data.trajectories.yaws.col(rightmost_idx);
auto yaws = utils::shortest_angular_distance(
data.trajectories.yaws.col(rightmost_idx), yaws_between_points).abs();
rightmost_yaw, yaws_between_points).abs();
auto yaws_between_points_corrected = utils::point_corrected_yaws(yaws, yaws_between_points);
auto corrected_yaws = utils::shortest_angular_distance(
data.trajectories.yaws.col(rightmost_idx), yaws_between_points_corrected).abs();
rightmost_yaw, yaws_between_points_corrected).abs();
if (power_ > 1u) {
data.costs += (corrected_yaws * weight_).pow(power_);
} else {
Expand All @@ -131,10 +133,11 @@ void PathAngleCritic::score(CriticData & data)
}
case PathAngleMode::CONSIDER_FEASIBLE_PATH_ORIENTATIONS:
{
auto rightmost_yaw = data.trajectories.yaws.col(rightmost_idx);
auto yaws_between_points_corrected =
utils::point_corrected_yaws(yaws_between_points, goal_yaw);
auto corrected_yaws = utils::shortest_angular_distance(
data.trajectories.yaws.col(rightmost_idx), yaws_between_points_corrected).abs();
rightmost_yaw, yaws_between_points_corrected).abs();
if (power_ > 1u) {
data.costs += (corrected_yaws * weight_).pow(power_);
} else {
Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critics/twirling_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ void TwirlingCritic::score(CriticData & data)
}

if (power_ > 1u) {
data.costs += Eigen::pow((Eigen::abs(data.state.wz)).rowwise().mean().eval() * weight_, power_);
data.costs += ((data.state.wz.abs().rowwise().mean()) * weight_).pow(power_).eval();
} else {
data.costs += (Eigen::abs(data.state.wz)).rowwise().mean().eval() * weight_;
data.costs += ((data.state.wz.abs().rowwise().mean()) * weight_).eval();
}
}

Expand Down
12 changes: 6 additions & 6 deletions nav2_mppi_controller/src/critics/velocity_deadband_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ void VelocityDeadbandCritic::score(CriticData & data)
data.costs += ((((fabs(deadband_velocities_[0]) - data.state.vx.abs()).max(0.0f) +
(fabs(deadband_velocities_[1]) - data.state.vy.abs()).max(0.0f) +
(fabs(deadband_velocities_[2]) - data.state.wz.abs()).max(0.0f)) *
data.model_dt).rowwise().sum() * weight_).pow(power_);
data.model_dt).rowwise().sum() * weight_).pow(power_).eval();
} else {
data.costs += (((fabs(deadband_velocities_[0]) - data.state.vx.abs()).max(0.0f) +
data.costs += ((((fabs(deadband_velocities_[0]) - data.state.vx.abs()).max(0.0f) +
(fabs(deadband_velocities_[1]) - data.state.vy.abs()).max(0.0f) +
(fabs(deadband_velocities_[2]) - data.state.wz.abs()).max(0.0f)) *
data.model_dt).rowwise().sum() * weight_;
data.model_dt).rowwise().sum() * weight_).eval();
}
return;
}
Expand All @@ -63,12 +63,12 @@ void VelocityDeadbandCritic::score(CriticData & data)
data.costs += ((((fabs(deadband_velocities_[0]) - data.state.vx.abs()).max(0.0f) +
(fabs(deadband_velocities_[1]) - data.state.vy.abs()).max(0.0f) +
(fabs(deadband_velocities_[2]) - data.state.wz.abs()).max(0.0f)) *
data.model_dt).rowwise().sum() * weight_).pow(power_);
data.model_dt).rowwise().sum() * weight_).pow(power_).eval();
} else {
data.costs += (((fabs(deadband_velocities_[0]) - data.state.vx.abs()).max(0.0f) +
data.costs += ((((fabs(deadband_velocities_[0]) - data.state.vx.abs()).max(0.0f) +
(fabs(deadband_velocities_[1]) - data.state.vy.abs()).max(0.0f) +
(fabs(deadband_velocities_[2]) - data.state.wz.abs()).max(0.0f)) *
data.model_dt).rowwise().sum() * weight_;
data.model_dt).rowwise().sum() * weight_).eval();
}
return;
}
Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/trajectory_visualizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ void TrajectoryVisualizer::add(
const float shape_1 = static_cast<float>(n_cols);
points_->markers.reserve(floor(n_rows / trajectory_step_) * floor(n_cols * time_step_));

for (size_t i = 0; i != n_rows; i += trajectory_step_) {
for (size_t j = 0; j != n_cols; j += time_step_) {
for (size_t i = 0; i < n_rows; i += trajectory_step_) {
for (size_t j = 0; j < n_cols; j += time_step_) {
const float j_flt = static_cast<float>(j);
float blue_component = 1.0f - j_flt / shape_1;
float green_component = j_flt / shape_1;
Expand Down
Loading

0 comments on commit 847f837

Please sign in to comment.