Skip to content

Commit

Permalink
adding in path angle update
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveMacenski committed Mar 11, 2024
1 parent 7a0e651 commit 81f5480
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
8 changes: 4 additions & 4 deletions nav2_mppi_controller/src/critics/cost_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ void CostCritic::score(CriticData & data)
bool all_trajectories_collide = true;
for (size_t i = 0; i < data.trajectories.x.shape(0); ++i) {
bool trajectory_collide = false;
const auto & traj_x = xt::view(data.trajectories.x, i, xt::all());
const auto & traj_y = xt::view(data.trajectories.y, i, xt::all());
const auto & traj_yaw = xt::view(data.trajectories.yaws, i, xt::all());
const auto traj_x = xt::view(data.trajectories.x, i, xt::all());
const auto traj_y = xt::view(data.trajectories.y, i, xt::all());
const auto traj_yaw = xt::view(data.trajectories.yaws, i, xt::all());
pose_cost = 0.0f;

for (size_t j = 0; j < traj_len; j++) {
Expand Down Expand Up @@ -157,7 +157,7 @@ void CostCritic::score(CriticData & data)
}
}

data.costs += xt::pow((weight_ * repulsive_cost / traj_len), power_);
data.costs += xt::pow((std::move(repulsive_cost) * (weight_ / static_cast<float>(traj_len))), power_);
data.fail_flag = all_trajectories_collide;
}

Expand Down
3 changes: 1 addition & 2 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ void PathAlignCritic::score(CriticData & data)
const size_t batch_size = data.trajectories.x.shape(0);
const size_t time_steps = data.trajectories.x.shape(1);
auto && cost = xt::xtensor<float, 1>::from_shape({data.costs.shape(0)});
cost.fill(0.0f);

// Find integrated distance in the path
std::vector<float> path_integrated_distances(path_segments_count, 0.0f);
Expand Down Expand Up @@ -126,8 +127,6 @@ void PathAlignCritic::score(CriticData & data)
}
if (num_samples > 0) {
cost[t] = summed_path_dist / num_samples;
} else {
cost[t] = 0.0f;
}
}

Expand Down
21 changes: 13 additions & 8 deletions nav2_mppi_controller/src/critics/path_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,28 @@ void PathAngleCritic::score(CriticData & data)
}

auto yaws_between_points = xt::atan2(
goal_y - data.trajectories.y,
goal_x - data.trajectories.x);

auto yaws =
xt::abs(utils::shortest_angular_distance(data.trajectories.yaws, yaws_between_points));
goal_y - xt::view(data.trajectories.y, xt::all(), -1),
goal_x - xt::view(data.trajectories.x, xt::all(), -1));

switch (mode_) {
case PathAngleMode::FORWARD_PREFERENCE:
{
data.costs += xt::pow(xt::mean(yaws, {1}, immediate) * weight_, power_);
auto yaws =
xt::abs(utils::shortest_angular_distance(
xt::view(data.trajectories.yaws, xt::all(), -1), yaws_between_points));
data.costs += xt::pow(yaws * weight_, power_);
return;
}
case PathAngleMode::NO_DIRECTIONAL_PREFERENCE:
{
auto yaws =
xt::abs(utils::shortest_angular_distance(
xt::view(data.trajectories.yaws, xt::all(), -1), yaws_between_points));
const auto yaws_between_points_corrected = xt::where(
yaws < M_PIF_2, yaws_between_points, utils::normalize_angles(yaws_between_points + M_PIF));
const auto corrected_yaws = xt::abs(
utils::shortest_angular_distance(data.trajectories.yaws, yaws_between_points_corrected));
utils::shortest_angular_distance(
xt::view(data.trajectories.yaws, xt::all(), -1), yaws_between_points_corrected));
data.costs += xt::pow(xt::mean(corrected_yaws, {1}, immediate) * weight_, power_);
return;
}
Expand All @@ -126,7 +130,8 @@ void PathAngleCritic::score(CriticData & data)
xt::abs(utils::shortest_angular_distance(yaws_between_points, goal_yaw)) < M_PIF_2,
yaws_between_points, utils::normalize_angles(yaws_between_points + M_PIF));
const auto corrected_yaws = xt::abs(
utils::shortest_angular_distance(data.trajectories.yaws, yaws_between_points_corrected));
utils::shortest_angular_distance(
xt::view(data.trajectories.yaws, xt::all(), -1), yaws_between_points_corrected));
data.costs += xt::pow(xt::mean(corrected_yaws, {1}, immediate) * weight_, power_);
return;
}
Expand Down

0 comments on commit 81f5480

Please sign in to comment.