diff --git a/nav2_mppi_controller/CMakeLists.txt b/nav2_mppi_controller/CMakeLists.txt index 3ca6735e40b..dac13f6edef 100644 --- a/nav2_mppi_controller/CMakeLists.txt +++ b/nav2_mppi_controller/CMakeLists.txt @@ -14,6 +14,7 @@ set(XTENSOR_USE_XSIMD 1) find_package(ament_cmake REQUIRED) find_package(xtensor REQUIRED) find_package(xsimd REQUIRED) +find_package(rosidl_default_generators REQUIRED) include_directories( include @@ -33,6 +34,7 @@ set(dependencies_pkgs tf2_geometry_msgs tf2_eigen tf2_ros + std_msgs ) foreach(pkg IN LISTS dependencies_pkgs) @@ -41,6 +43,12 @@ endforeach() nav2_package() +rosidl_generate_interfaces(${PROJECT_NAME} + "msg/CriticScore.msg" + "msg/CriticScores.msg" + DEPENDENCIES std_msgs +) + include(CheckCXXCompilerFlag) check_cxx_compiler_flag("-mno-avx512f" COMPILER_SUPPORTS_AVX512) @@ -120,8 +128,14 @@ if(BUILD_TESTING) # add_subdirectory(benchmark) endif() +rosidl_get_typesupport_target(cpp_typesupport_target + ${PROJECT_NAME} rosidl_typesupport_cpp) + +target_link_libraries(mppi_controller "${cpp_typesupport_target}") + ament_export_libraries(${libraries}) ament_export_dependencies(${dependencies_pkgs}) +ament_export_dependencies(rosidl_default_runtime) ament_export_include_directories(include) pluginlib_export_plugin_description_file(nav2_core mppic.xml) pluginlib_export_plugin_description_file(nav2_mppi_controller critics.xml) diff --git a/nav2_mppi_controller/include/nav2_mppi_controller/controller.hpp b/nav2_mppi_controller/include/nav2_mppi_controller/controller.hpp index 2d7f2aa8ca4..11b5a3875da 100644 --- a/nav2_mppi_controller/include/nav2_mppi_controller/controller.hpp +++ b/nav2_mppi_controller/include/nav2_mppi_controller/controller.hpp @@ -23,6 +23,8 @@ #include "nav2_mppi_controller/tools/trajectory_visualizer.hpp" #include "nav2_mppi_controller/models/constraints.hpp" #include "nav2_mppi_controller/tools/utils.hpp" +#include "nav2_mppi_controller/msg/critic_score.hpp" +#include "nav2_mppi_controller/msg/critic_scores.hpp" #include "nav2_core/controller.hpp" #include "nav2_core/goal_checker.hpp" @@ -121,10 +123,14 @@ class MPPIController : public nav2_core::Controller TrajectoryVisualizer trajectory_visualizer_; bool visualize_; + bool publish_critics_; double reset_period_; // Last time computeVelocityCommands was called rclcpp::Time last_time_called_; + + std::shared_ptr> + critics_publisher_; }; } // namespace nav2_mppi_controller diff --git a/nav2_mppi_controller/include/nav2_mppi_controller/critic_manager.hpp b/nav2_mppi_controller/include/nav2_mppi_controller/critic_manager.hpp index b2e2c178d96..be294b11701 100644 --- a/nav2_mppi_controller/include/nav2_mppi_controller/critic_manager.hpp +++ b/nav2_mppi_controller/include/nav2_mppi_controller/critic_manager.hpp @@ -46,12 +46,12 @@ class CriticManager * @brief Constructor for mppi::CriticManager */ CriticManager() = default; - + /** * @brief Virtual Destructor for mppi::CriticManager */ virtual ~CriticManager() = default; - + /** * @brief Configure critic manager on bringup and load plugins * @param parent WeakPtr to node @@ -69,6 +69,8 @@ class CriticManager */ void evalTrajectoriesScores(CriticData & data) const; + std::vector getCriticNames() const; + protected: /** * @brief Get parameters (critics to load) diff --git a/nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp b/nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp index b7200575f70..d2cb56903ef 100644 --- a/nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp +++ b/nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp @@ -104,6 +104,15 @@ class Optimizer */ xt::xtensor getOptimizedTrajectory(); + + /** + * @brief Get the critic costs for given trajectory + * @return Names and costs of the critics + */ + xt::xtensor getOptimizationResults(); + + std::vector getCriticNames() const; + /** * @brief Set the maximum speed based on the speed limits callback * @param speed_limit Limit of the speed for use diff --git a/nav2_mppi_controller/msg/CriticScore.msg b/nav2_mppi_controller/msg/CriticScore.msg new file mode 100644 index 00000000000..fdc2901a901 --- /dev/null +++ b/nav2_mppi_controller/msg/CriticScore.msg @@ -0,0 +1,2 @@ +std_msgs/String name +std_msgs/Float32 score diff --git a/nav2_mppi_controller/msg/CriticScores.msg b/nav2_mppi_controller/msg/CriticScores.msg new file mode 100644 index 00000000000..4da10efe3db --- /dev/null +++ b/nav2_mppi_controller/msg/CriticScores.msg @@ -0,0 +1,2 @@ +std_msgs/Header header # when msg was sent +CriticScore[] critic_scores diff --git a/nav2_mppi_controller/package.xml b/nav2_mppi_controller/package.xml index da7fbf07cb3..f56e1f95fb2 100644 --- a/nav2_mppi_controller/package.xml +++ b/nav2_mppi_controller/package.xml @@ -10,6 +10,9 @@ ament_cmake ament_cmake_ros + rosidl_default_generators + + rosidl_default_runtime rclcpp nav2_common @@ -33,6 +36,8 @@ ament_lint_auto ament_lint_common ament_cmake_gtest + + rosidl_interface_packages ament_cmake diff --git a/nav2_mppi_controller/src/controller.cpp b/nav2_mppi_controller/src/controller.cpp index 8a631f97526..dd005a29f42 100644 --- a/nav2_mppi_controller/src/controller.cpp +++ b/nav2_mppi_controller/src/controller.cpp @@ -39,6 +39,7 @@ void MPPIController::configure( // Get high-level controller parameters auto getParam = parameters_handler_->getParamGetter(name_); getParam(visualize_, "visualize", false); + getParam(publish_critics_, "publish_critics", false); getParam(reset_period_, "reset_period", 1.0); // Configure composed objects @@ -48,6 +49,11 @@ void MPPIController::configure( parent_, name_, costmap_ros_->getGlobalFrameID(), parameters_handler_.get()); + if (publish_critics_) { + critics_publisher_ = node->create_publisher( + "/mppi_critic_scores", 1); + } + RCLCPP_INFO(logger_, "Configured MPPI Controller: %s", name_.c_str()); } @@ -61,6 +67,7 @@ void MPPIController::cleanup() void MPPIController::activate() { + if (publish_critics_) critics_publisher_->on_activate(); trajectory_visualizer_.on_activate(); parameters_handler_->start(); RCLCPP_INFO(logger_, "Activated MPPI Controller: %s", name_.c_str()); @@ -68,6 +75,7 @@ void MPPIController::activate() void MPPIController::deactivate() { + if (publish_critics_) critics_publisher_->on_deactivate(); trajectory_visualizer_.on_deactivate(); RCLCPP_INFO(logger_, "Deactivated MPPI Controller: %s", name_.c_str()); } @@ -110,6 +118,36 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands( visualize(std::move(transformed_plan)); } + if (publish_critics_) { + std::vector critic_names = optimizer_.getCriticNames(); + xt::xtensor critic_costs = optimizer_.getOptimizationResults(); + + // log critic names and costs + for (size_t i = 0; i < critic_names.size(); i++) { + RCLCPP_DEBUG(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]); + } + + // make msg + auto critic_scores_ = std::make_unique(); + if (critic_names.size() != critic_costs.size()) { + RCLCPP_ERROR( + logger_, + "Critic names %ld and costs %ld size mismatch!", + critic_names.size(), critic_costs.size()); + return cmd; + } + + for (size_t i = 0; i < critic_names.size(); i++) { + nav2_mppi_controller::msg::CriticScore critic_score; + critic_score.name.data = critic_names[i]; + critic_score.score.data = critic_costs[i]; + critic_scores_->critic_scores.push_back(critic_score); + } + + critic_scores_->header.stamp = clock_->now(); + critics_publisher_->publish(std::move(critic_scores_)); + } + return cmd; } diff --git a/nav2_mppi_controller/src/critic_manager.cpp b/nav2_mppi_controller/src/critic_manager.cpp index 2a7a77a2346..0e21a7d8775 100644 --- a/nav2_mppi_controller/src/critic_manager.cpp +++ b/nav2_mppi_controller/src/critic_manager.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "nav2_mppi_controller/critic_manager.hpp" namespace mppi @@ -64,6 +66,11 @@ std::string CriticManager::getFullName(const std::string & name) return "mppi::critics::" + name; } +std::vector CriticManager::getCriticNames() const +{ + return critic_names_; +} + void CriticManager::evalTrajectoriesScores( CriticData & data) const { @@ -74,5 +81,4 @@ void CriticManager::evalTrajectoriesScores( critics_[q]->score(data); } } - } // namespace mppi diff --git a/nav2_mppi_controller/src/optimizer.cpp b/nav2_mppi_controller/src/optimizer.cpp index 3de703bcc44..c0a1f3d41dc 100644 --- a/nav2_mppi_controller/src/optimizer.cpp +++ b/nav2_mppi_controller/src/optimizer.cpp @@ -161,6 +161,48 @@ void Optimizer::optimize() } } +xt::xtensor Optimizer::getOptimizationResults() +{ + // get the final optimized trajectory + const xt::xtensor optimized_trajectory = getOptimizedTrajectory(); + xt::xtensor costs = xt::zeros({1}); + + // evalTrajectory evals multiple trajectories, but we only have one + // create a dummy_trajectories object and put the optimized in it + models::Trajectories dummy_trajectories; + dummy_trajectories.reset(1, settings_.time_steps); + dummy_trajectories.x += xt::view(optimized_trajectory, xt::all(), 0); + dummy_trajectories.y += xt::view(optimized_trajectory, xt::all(), 1); + dummy_trajectories.yaws += xt::view(optimized_trajectory, xt::all(), 2); + + // create a dummy_data object to pass to evalTrajectory + CriticData dummy_data = { + state_, dummy_trajectories, path_, costs, settings_.model_dt, + false, critics_data_.goal_checker, critics_data_.motion_model, std::nullopt, std::nullopt}; + dummy_data.furthest_reached_path_point.reset(); + dummy_data.path_pts_valid.reset(); + + // use evalTrajectoriesScores + critic_manager_.evalTrajectoriesScores(dummy_data); + + size_t num_critics = critic_manager_.getCriticNames().size(); + + xt::xtensor critic_scores = xt::zeros(std::vector{num_critics}); + for (size_t i = 0; i < num_critics; i++) { + critic_scores(i) = dummy_data.costs(0); // Assuming costs are updated for each critic + } + + return critic_scores; + + // evaluate the optimized trajectory + // return critic_manager_.evalTrajectory(dummy_data); +} + +std::vector Optimizer::getCriticNames() const +{ + return critic_manager_.getCriticNames(); +} + bool Optimizer::fallback(bool fail) { static size_t counter = 0;