Skip to content

Commit

Permalink
Fixed CostCritic issue and added test for shiftColumn method
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush1285 <[email protected]>
  • Loading branch information
Ayush1285 committed Oct 16, 2024
1 parent 0327754 commit 2f8bd63
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ class MotionModel
virtual void predict(models::State & state)
{
const bool is_holo = isHolonomic();
float && max_delta_vx = model_dt_ * control_constraints_.ax_max;
float && min_delta_vx = model_dt_ * control_constraints_.ax_min;
float && max_delta_vy = model_dt_ * control_constraints_.ay_max;
float && max_delta_wz = model_dt_ * control_constraints_.az_max;
float max_delta_vx = model_dt_ * control_constraints_.ax_max;
float min_delta_vx = model_dt_ * control_constraints_.ax_min;
float max_delta_vy = model_dt_ * control_constraints_.ay_max;
float max_delta_wz = model_dt_ * control_constraints_.az_max;

unsigned int n_rows = state.vx.rows();
unsigned int n_cols = state.vx.cols();
Expand All @@ -77,18 +77,18 @@ class MotionModel
// column-major fashion to utilize L1 cache as much as possible
for (unsigned int i = 1; i != n_cols; i++) {
for (unsigned int j = 0; j != n_rows; j++) {
float & vx_last = state.vx(j, i - 1);
float vx_last = state.vx(j, i - 1);
float & cvx_curr = state.cvx(j, i - 1);
cvx_curr = std::min(vx_last + max_delta_vx, std::max(cvx_curr, vx_last + min_delta_vx));
state.vx(j, i) = cvx_curr;

float & wz_last = state.wz(j, i - 1);
float wz_last = state.wz(j, i - 1);
float & cwz_curr = state.cwz(j, i - 1);
cwz_curr = std::min(wz_last + max_delta_wz, std::max(cwz_curr, wz_last - max_delta_wz));
state.wz(j, i) = cwz_curr;

if (is_holo) {
float & vy_last = state.vy(j, i - 1);
float vy_last = state.vy(j, i - 1);
float & cvy_curr = state.cvy(j, i - 1);
cvy_curr = std::min(vy_last + max_delta_vy, std::max(cvy_curr, vy_last - max_delta_vy));
state.vy(j, i) = cvy_curr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ struct Pose2D
inline void shiftColumnsByOnePlace(Eigen::Ref<Eigen::ArrayXXf> e, int direction)
{
int size = e.size();
if(size == 1) {return;}
if((e.cols() == 1 || e.rows() == 1) && size > 1) {
auto start_ptr = direction == 1 ? e.data() + size - 2 : e.data() + 1;
auto end_ptr = direction == 1 ? e.data() : e.data() + size - 1;
Expand Down Expand Up @@ -648,12 +649,6 @@ inline auto point_corrected_yaws(
yaws_between_points_corrected[i] = yaws[i] < M_PIF_2 ?
yaw_between_points : angles::normalize_angle(yaw_between_points + M_PIF);
}

// binaryExpr slower than for loop
// Eigen::ArrayXf yaws_between_points_corrected = yaws.binaryExpr(
// yaws_between_points, [&](const float & yaw, const float & yaw_between_points) {
// return yaw < M_PIF_2 ? yaw_between_points : normalize_anglef(yaw_between_points + M_PIF);
// });
return yaws_between_points_corrected;
}

Expand All @@ -668,13 +663,6 @@ inline auto point_corrected_yaws(
angles::normalize_angle(yaw_between_points - goal_yaw)) < M_PIF_2 ?
yaw_between_points : angles::normalize_angle(yaw_between_points + M_PIF);
}

// unaryExpr slower than for loop
// Eigen::ArrayXf yaws_between_points_corrected = yaws_between_points.unaryExpr(
// [&](const float & yaw_between_points) {
// return fabs(normalize_anglef(yaw_between_points - goal_yaw)) < M_PIF_2 ?
// yaw_between_points : normalize_anglef(yaw_between_points + M_PIF);
// });
return yaws_between_points_corrected;
}

Expand Down
6 changes: 3 additions & 3 deletions nav2_mppi_controller/src/critics/cost_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ void CostCritic::score(CriticData & data)
}

Eigen::ArrayXf repulsive_cost(data.costs.rows());
repulsive_cost.setZero();
bool all_trajectories_collide = true;

int strided_traj_cols = data.trajectories.x.cols() / trajectory_point_step_ + 1;
int strided_traj_cols = floor(data.trajectories.x.cols() / trajectory_point_step_);
int strided_traj_rows = data.trajectories.x.rows();
int outer_stride = strided_traj_rows * trajectory_point_step_;

Expand All @@ -140,8 +141,7 @@ void CostCritic::score(CriticData & data)
for (int i = 0; i < strided_traj_rows; ++i) {
bool trajectory_collide = false;
float pose_cost = 0.0f;
float & traj_cost = repulsive_cost[i];
traj_cost = 0.0f;
float & traj_cost = repulsive_cost(i);

for (int j = 0; j < strided_traj_cols; j++) {
float Tx = traj_x(i, j);
Expand Down
26 changes: 13 additions & 13 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void Optimizer::reset()

settings_.constraints = settings_.base_constraints;

costs_ = Eigen::ArrayXf::Zero(settings_.batch_size);
costs_.setZero(settings_.batch_size);
generated_trajectories_.reset(settings_.batch_size, settings_.time_steps);

noise_generator_.reset(settings_, isHolonomic());
Expand Down Expand Up @@ -202,7 +202,7 @@ void Optimizer::prepare(
state_.pose = robot_pose;
state_.speed = robot_speed;
path_ = utils::toTensor(plan);
costs_.fill(0.0f);
costs_.setZero();

critics_data_.fail_flag = false;
critics_data_.goal_checker = goal_checker;
Expand Down Expand Up @@ -404,21 +404,21 @@ void Optimizer::updateControlSequence()
{
const bool is_holo = isHolonomic();
auto & s = settings_;
auto && bounded_noises_vx = state_.cvx.rowwise() - control_sequence_.vx.transpose();
auto && bounded_noises_wz = state_.cwz.rowwise() - control_sequence_.wz.transpose();
costs_ += s.gamma / powf(s.sampling_std.vx, 2) *
(bounded_noises_vx.rowwise() * control_sequence_.vx.transpose()).rowwise().sum();
costs_ += s.gamma / powf(s.sampling_std.wz, 2) *
(bounded_noises_wz.rowwise() * control_sequence_.wz.transpose()).rowwise().sum();
auto bounded_noises_vx = state_.cvx.rowwise() - control_sequence_.vx.transpose();
auto bounded_noises_wz = state_.cwz.rowwise() - control_sequence_.wz.transpose();
costs_ += (s.gamma / powf(s.sampling_std.vx, 2) *
(bounded_noises_vx.rowwise() * control_sequence_.vx.transpose()).rowwise().sum()).eval();
costs_ += (s.gamma / powf(s.sampling_std.wz, 2) *
(bounded_noises_wz.rowwise() * control_sequence_.wz.transpose()).rowwise().sum()).eval();
if (is_holo) {
auto bounded_noises_vy = state_.cvy.rowwise() - control_sequence_.vy.transpose();
costs_ += s.gamma / powf(s.sampling_std.vy, 2) *
(bounded_noises_vy.rowwise() * control_sequence_.vy.transpose()).rowwise().sum();
costs_ += (s.gamma / powf(s.sampling_std.vy, 2) *
(bounded_noises_vy.rowwise() * control_sequence_.vy.transpose()).rowwise().sum()).eval();
}

auto && costs_normalized = costs_ - costs_.minCoeff();
auto && exponents = ((-1 / settings_.temperature * costs_normalized).exp()).eval();
auto && softmaxes = (exponents / exponents.sum()).eval();
auto costs_normalized = costs_ - costs_.minCoeff();
auto exponents = ((-1 / settings_.temperature * costs_normalized).exp()).eval();
auto softmaxes = (exponents / exponents.sum()).eval();

control_sequence_.vx = (state_.cvx.colwise() * softmaxes).colwise().sum();
control_sequence_.wz = (state_.cwz.colwise() * softmaxes).colwise().sum();
Expand Down
65 changes: 65 additions & 0 deletions nav2_mppi_controller/test/utils_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,68 @@ TEST(UtilsTests, RemovePosesAfterPathInversionTest)
EXPECT_EQ(path.poses.size(), 11u);
EXPECT_EQ(path.poses.back().pose.position.x, 10);
}

TEST(UtilsTests, ShiftColumnsByOnePlaceTest)
{
// Try with scalar value
Eigen::ArrayXf scalar_val(1);
scalar_val(0) = 5;
utils::shiftColumnsByOnePlace(scalar_val, 1);
EXPECT_EQ(scalar_val.size(), 1);
EXPECT_EQ(scalar_val(0), 5);

// Try with one dimensional array, shift right
Eigen::ArrayXf array_1d(4);
array_1d << 1, 2, 3, 4;
utils::shiftColumnsByOnePlace(array_1d, 1);
EXPECT_EQ(array_1d.size(), 4);
EXPECT_EQ(array_1d(1), 1);
EXPECT_EQ(array_1d(2), 2);
EXPECT_EQ(array_1d(3), 3);

// Try with one dimensional array, shift left
array_1d(1) = 5;
utils::shiftColumnsByOnePlace(array_1d, -1);
EXPECT_EQ(array_1d.size(), 4);
EXPECT_EQ(array_1d(0), 5);
EXPECT_EQ(array_1d(1), 2);
EXPECT_EQ(array_1d(2), 3);

// Try with two dimensional array, shift right
// 1 2 3 4 1 1 2 3
// 5 6 7 8 -> 5 5 6 7
// 9 10 11 12 9 9 10 11
Eigen::ArrayXXf array_2d(3, 4);
array_2d << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12;
utils::shiftColumnsByOnePlace(array_2d, 1);
EXPECT_EQ(array_2d.rows(), 3);
EXPECT_EQ(array_2d.cols(), 4);
EXPECT_EQ(array_2d(0, 1), 1);
EXPECT_EQ(array_2d(1, 1), 5);
EXPECT_EQ(array_2d(2, 1), 9);
EXPECT_EQ(array_2d(0, 2), 2);
EXPECT_EQ(array_2d(1, 2), 6);
EXPECT_EQ(array_2d(2, 2), 10);
EXPECT_EQ(array_2d(0, 3), 3);
EXPECT_EQ(array_2d(1, 3), 7);
EXPECT_EQ(array_2d(2, 3), 11);

array_2d.col(0).setZero();

// Try with two dimensional array, shift left
// 0 1 2 3 1 2 3 3
// 0 5 6 7 -> 5 6 7 7
// 0 9 10 11 9 10 11 11
utils::shiftColumnsByOnePlace(array_2d, -1);
EXPECT_EQ(array_2d.rows(), 3);
EXPECT_EQ(array_2d.cols(), 4);
EXPECT_EQ(array_2d(0, 0), 1);
EXPECT_EQ(array_2d(1, 0), 5);
EXPECT_EQ(array_2d(2, 0), 9);
EXPECT_EQ(array_2d(0, 1), 2);
EXPECT_EQ(array_2d(1, 1), 6);
EXPECT_EQ(array_2d(2, 1), 10);
EXPECT_EQ(array_2d(0, 2), 3);
EXPECT_EQ(array_2d(1, 2), 7);
EXPECT_EQ(array_2d(2, 2), 11);
}

0 comments on commit 2f8bd63

Please sign in to comment.