Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Humble enjoy mppi critic pub #18

Open
wants to merge 7 commits into
base: humble
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions nav2_mppi_controller/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(XTENSOR_USE_XSIMD 1)
find_package(ament_cmake REQUIRED)
find_package(xtensor REQUIRED)
find_package(xsimd REQUIRED)
find_package(rosidl_default_generators REQUIRED)

include_directories(
include
Expand All @@ -33,6 +34,7 @@ set(dependencies_pkgs
tf2_geometry_msgs
tf2_eigen
tf2_ros
std_msgs
)

foreach(pkg IN LISTS dependencies_pkgs)
Expand All @@ -41,6 +43,12 @@ endforeach()

nav2_package()

rosidl_generate_interfaces(${PROJECT_NAME}
"msg/CriticScore.msg"
"msg/CriticScores.msg"
DEPENDENCIES std_msgs
)

include(CheckCXXCompilerFlag)

check_cxx_compiler_flag("-mno-avx512f" COMPILER_SUPPORTS_AVX512)
Expand Down Expand Up @@ -120,8 +128,14 @@ if(BUILD_TESTING)
# add_subdirectory(benchmark)
endif()

rosidl_get_typesupport_target(cpp_typesupport_target
${PROJECT_NAME} rosidl_typesupport_cpp)

target_link_libraries(mppi_controller "${cpp_typesupport_target}")

ament_export_libraries(${libraries})
ament_export_dependencies(${dependencies_pkgs})
ament_export_dependencies(rosidl_default_runtime)
ament_export_include_directories(include)
pluginlib_export_plugin_description_file(nav2_core mppic.xml)
pluginlib_export_plugin_description_file(nav2_mppi_controller critics.xml)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "nav2_mppi_controller/tools/trajectory_visualizer.hpp"
#include "nav2_mppi_controller/models/constraints.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "nav2_mppi_controller/msg/critic_score.hpp"
#include "nav2_mppi_controller/msg/critic_scores.hpp"

#include "nav2_core/controller.hpp"
#include "nav2_core/goal_checker.hpp"
Expand Down Expand Up @@ -121,10 +123,14 @@ class MPPIController : public nav2_core::Controller
TrajectoryVisualizer trajectory_visualizer_;

bool visualize_;
bool publish_critics_;

double reset_period_;
// Last time computeVelocityCommands was called
rclcpp::Time last_time_called_;

std::shared_ptr<rclcpp_lifecycle::LifecyclePublisher<nav2_mppi_controller::msg::CriticScores>>
critics_publisher_;
};

} // namespace nav2_mppi_controller
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class CriticManager
* @brief Constructor for mppi::CriticManager
*/
CriticManager() = default;

/**
* @brief Virtual Destructor for mppi::CriticManager
*/
virtual ~CriticManager() = default;

/**
* @brief Configure critic manager on bringup and load plugins
* @param parent WeakPtr to node
Expand All @@ -69,6 +69,8 @@ class CriticManager
*/
void evalTrajectoriesScores(CriticData & data) const;

std::vector<std::string> getCriticNames() const;

protected:
/**
* @brief Get parameters (critics to load)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class Optimizer
*/
xt::xtensor<float, 2> getOptimizedTrajectory();


/**
* @brief Get the critic costs for given trajectory
* @return Names and costs of the critics
*/
xt::xtensor<float, 1> getOptimizationResults();

std::vector<std::string> getCriticNames() const;

/**
* @brief Set the maximum speed based on the speed limits callback
* @param speed_limit Limit of the speed for use
Expand Down
2 changes: 2 additions & 0 deletions nav2_mppi_controller/msg/CriticScore.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
std_msgs/String name
std_msgs/Float32 score
2 changes: 2 additions & 0 deletions nav2_mppi_controller/msg/CriticScores.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
std_msgs/Header header # when msg was sent
CriticScore[] critic_scores
5 changes: 5 additions & 0 deletions nav2_mppi_controller/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_ros</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>

<exec_depend>rosidl_default_runtime</exec_depend>

<depend>rclcpp</depend>
<depend>nav2_common</depend>
Expand All @@ -33,6 +36,8 @@
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>
<test_depend>ament_cmake_gtest</test_depend>

<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
<nav2_core plugin="${prefix}/mppic.xml" />
Expand Down
38 changes: 38 additions & 0 deletions nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void MPPIController::configure(
// Get high-level controller parameters
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(visualize_, "visualize", false);
getParam(publish_critics_, "publish_critics", false);
getParam(reset_period_, "reset_period", 1.0);

// Configure composed objects
Expand All @@ -48,6 +49,11 @@ void MPPIController::configure(
parent_, name_,
costmap_ros_->getGlobalFrameID(), parameters_handler_.get());

if (publish_critics_) {
critics_publisher_ = node->create_publisher<nav2_mppi_controller::msg::CriticScores>(
"/mppi_critic_scores", 1);
}

RCLCPP_INFO(logger_, "Configured MPPI Controller: %s", name_.c_str());
}

Expand All @@ -61,13 +67,15 @@ void MPPIController::cleanup()

void MPPIController::activate()
{
if (publish_critics_) critics_publisher_->on_activate();
trajectory_visualizer_.on_activate();
parameters_handler_->start();
RCLCPP_INFO(logger_, "Activated MPPI Controller: %s", name_.c_str());
}

void MPPIController::deactivate()
{
if (publish_critics_) critics_publisher_->on_deactivate();
trajectory_visualizer_.on_deactivate();
RCLCPP_INFO(logger_, "Deactivated MPPI Controller: %s", name_.c_str());
}
Expand Down Expand Up @@ -110,6 +118,36 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
visualize(std::move(transformed_plan));
}

if (publish_critics_) {
std::vector<std::string> critic_names = optimizer_.getCriticNames();
xt::xtensor<float, 1> critic_costs = optimizer_.getOptimizationResults();

// log critic names and costs
for (size_t i = 0; i < critic_names.size(); i++) {
RCLCPP_DEBUG(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]);
}

// make msg
auto critic_scores_ = std::make_unique<nav2_mppi_controller::msg::CriticScores>();
if (critic_names.size() != critic_costs.size()) {
RCLCPP_ERROR(
logger_,
"Critic names %ld and costs %ld size mismatch!",
critic_names.size(), critic_costs.size());
return cmd;
}

for (size_t i = 0; i < critic_names.size(); i++) {
nav2_mppi_controller::msg::CriticScore critic_score;
critic_score.name.data = critic_names[i];
critic_score.score.data = critic_costs[i];
critic_scores_->critic_scores.push_back(critic_score);
}

critic_scores_->header.stamp = clock_->now();
critics_publisher_->publish(std::move(critic_scores_));
}

return cmd;
}

Expand Down
8 changes: 7 additions & 1 deletion nav2_mppi_controller/src/critic_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <xtensor/xtensor.hpp>

#include "nav2_mppi_controller/critic_manager.hpp"

namespace mppi
Expand Down Expand Up @@ -64,6 +66,11 @@ std::string CriticManager::getFullName(const std::string & name)
return "mppi::critics::" + name;
}

std::vector<std::string> CriticManager::getCriticNames() const
{
return critic_names_;
}

void CriticManager::evalTrajectoriesScores(
CriticData & data) const
{
Expand All @@ -74,5 +81,4 @@ void CriticManager::evalTrajectoriesScores(
critics_[q]->score(data);
}
}

} // namespace mppi
42 changes: 42 additions & 0 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,48 @@ void Optimizer::optimize()
}
}

xt::xtensor<float, 1> Optimizer::getOptimizationResults()
{
// get the final optimized trajectory
const xt::xtensor<float, 2> optimized_trajectory = getOptimizedTrajectory();
xt::xtensor<float, 1> costs = xt::zeros<float>({1});

// evalTrajectory evals multiple trajectories, but we only have one
// create a dummy_trajectories object and put the optimized in it
models::Trajectories dummy_trajectories;
dummy_trajectories.reset(1, settings_.time_steps);
dummy_trajectories.x += xt::view(optimized_trajectory, xt::all(), 0);
dummy_trajectories.y += xt::view(optimized_trajectory, xt::all(), 1);
dummy_trajectories.yaws += xt::view(optimized_trajectory, xt::all(), 2);

// create a dummy_data object to pass to evalTrajectory
CriticData dummy_data = {
state_, dummy_trajectories, path_, costs, settings_.model_dt,
false, critics_data_.goal_checker, critics_data_.motion_model, std::nullopt, std::nullopt};
dummy_data.furthest_reached_path_point.reset();
dummy_data.path_pts_valid.reset();

// use evalTrajectoriesScores
critic_manager_.evalTrajectoriesScores(dummy_data);

size_t num_critics = critic_manager_.getCriticNames().size();

xt::xtensor<float, 1> critic_scores = xt::zeros<float>(std::vector<size_t>{num_critics});
for (size_t i = 0; i < num_critics; i++) {
critic_scores(i) = dummy_data.costs(0); // Assuming costs are updated for each critic
}

return critic_scores;

// evaluate the optimized trajectory
// return critic_manager_.evalTrajectory(dummy_data);
}

std::vector<std::string> Optimizer::getCriticNames() const
{
return critic_manager_.getCriticNames();
}

bool Optimizer::fallback(bool fail)
{
static size_t counter = 0;
Expand Down