diff --git a/nav2_behavior_tree/CMakeLists.txt b/nav2_behavior_tree/CMakeLists.txt index 6d43f4e760e..ca7da7f1614 100644 --- a/nav2_behavior_tree/CMakeLists.txt +++ b/nav2_behavior_tree/CMakeLists.txt @@ -2,21 +2,22 @@ cmake_minimum_required(VERSION 3.5) project(nav2_behavior_tree CXX) find_package(ament_cmake REQUIRED) +find_package(behaviortree_cpp REQUIRED) +find_package(builtin_interfaces REQUIRED) +find_package(cv_bridge REQUIRED) +find_package(geometry_msgs REQUIRED) +find_package(nav_msgs REQUIRED) find_package(nav2_common REQUIRED) +find_package(nav2_msgs REQUIRED) +find_package(nav2_util REQUIRED) find_package(rclcpp REQUIRED) find_package(rclcpp_action REQUIRED) find_package(rclcpp_lifecycle REQUIRED) -find_package(builtin_interfaces REQUIRED) -find_package(geometry_msgs REQUIRED) find_package(sensor_msgs REQUIRED) -find_package(nav2_msgs REQUIRED) -find_package(nav_msgs REQUIRED) -find_package(behaviortree_cpp REQUIRED) -find_package(tf2_ros REQUIRED) -find_package(tf2_geometry_msgs REQUIRED) find_package(std_msgs REQUIRED) find_package(std_srvs REQUIRED) -find_package(nav2_util REQUIRED) +find_package(tf2_geometry_msgs REQUIRED) +find_package(tf2_ros REQUIRED) nav2_package() @@ -29,20 +30,21 @@ include_directories( set(library_name ${PROJECT_NAME}) set(dependencies + behaviortree_cpp + cv_bridge + geometry_msgs + nav_msgs + nav2_msgs + nav2_util rclcpp rclcpp_action rclcpp_lifecycle - geometry_msgs sensor_msgs - nav2_msgs - nav_msgs - behaviortree_cpp - tf2 - tf2_ros - tf2_geometry_msgs std_msgs std_srvs - nav2_util + tf2 + tf2_geometry_msgs + tf2_ros ) add_library(${library_name} SHARED @@ -119,6 +121,9 @@ list(APPEND plugin_libs nav2_goal_updated_condition_bt_node) add_library(nav2_is_path_valid_condition_bt_node SHARED plugins/condition/is_path_valid_condition.cpp) list(APPEND plugin_libs nav2_is_path_valid_condition_bt_node) +add_library(nav2_ml_vicinity_condition_bt_node SHARED plugins/condition/ml_vicinity_condition.cpp) +list(APPEND plugin_libs nav2_ml_vicinity_condition_bt_node) + add_library(nav2_time_expired_condition_bt_node SHARED plugins/condition/time_expired_condition.cpp) list(APPEND plugin_libs nav2_time_expired_condition_bt_node) diff --git a/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp index 5a9f255a9b7..e0761863e97 100644 --- a/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp +++ b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp @@ -20,7 +20,6 @@ #include "rclcpp/rclcpp.hpp" #include "behaviortree_cpp/condition_node.h" -#include "geometry_msgs/msg/pose_stamped.hpp" #include "nav2_msgs/srv/is_path_valid.hpp" namespace nav2_behavior_tree @@ -70,7 +69,7 @@ class IsPathValidCondition : public BT::ConditionNode private: rclcpp::Node::SharedPtr node_; rclcpp::Client::SharedPtr client_; - // The timeout value while waiting for a responce from the + // The timeout value while waiting for a response from the // is path valid service std::chrono::milliseconds server_timeout_; bool initialized_; diff --git a/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp new file mode 100644 index 00000000000..4c999c65979 --- /dev/null +++ b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp @@ -0,0 +1,92 @@ +// Copyright (c) 2024 Andy Zelenak +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ +#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ + +#include "behaviortree_cpp/condition_node.h" +#include "rclcpp/rclcpp.hpp" +#include "sensor_msgs/msg/image.hpp" +#include + +namespace nav2_behavior_tree +{ + +/** + * @brief A BT::ConditionNode, returns SUCCESS when a large language model says the vicinity is clear + * The input data has type sensor_msgs::msg::Image. + */ +class MLVicinityCondition : public BT::ConditionNode +{ +public: + /** + * @brief A constructor for nav2_behavior_tree::MLVicinityCondition + * @param condition_name Name for the XML tag for this node + * @param conf BT node configuration + */ + MLVicinityCondition( + const std::string & condition_name, + const BT::NodeConfiguration & conf); + + MLVicinityCondition() = delete; + + /** + * @brief The main override required by a BT action + * @return BT::NodeStatus Status of tick execution + */ + BT::NodeStatus tick() override; + + /** + * @brief Creates list of BT ports + * @return BT::PortsList Containing node-specific ports + */ + static BT::PortsList providedPorts() + { + return { + BT::InputPort("image_topic", "Image topic which is subscribed to"), + BT::InputPort("server_timeout") + }; + } + +private: + /** + * @brief Capture the latest image to send to the ML model + * + * @param msg + */ + void imageCallback(const sensor_msgs::msg::Image& msg) + { + const std::lock_guard lock(image_mutex_); + latest_image_ = cv_bridge::toCvShare(msg, sensor_msgs::image_encodings::TYPE_32FC1); + } + + /** + * @brief Send the prompt (or otherwise run) the AI model. This includes the latest image as input. + * + * @return true if successful + */ + [[nodiscard]] bool promptAIModel(); + + rclcpp::Node::SharedPtr node_; + // The timeout value while waiting for a response from the + // is path valid service + std::chrono::milliseconds server_timeout_; + rclcpp::Subscription::SharedPtr image_sub_; + mutable std::mutex image_mutex_; + std::optional latest_image_; +}; + +} // namespace nav2_behavior_tree + +#endif // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_ diff --git a/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp b/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp new file mode 100644 index 00000000000..943c04f3565 --- /dev/null +++ b/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2024 Andy Zelenak +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp" +#include +#include +#include +#include +#include +#include + +namespace nav2_behavior_tree +{ + +MLVicinityCondition::MLVicinityCondition( + const std::string & condition_name, + const BT::NodeConfiguration & conf) +: BT::ConditionNode(condition_name, conf) +{ + node_ = config().blackboard->get("node"); + + server_timeout_ = config().blackboard->template get("server_timeout"); + getInput("server_timeout", server_timeout_); + std::string image_topic; + getInput("image_topic", image_topic); + image_sub_ = node_->create_subscription( + image_topic, 10, std::bind(&MLVicinityCondition::imageCallback, + this, std::placeholders::_1)); +} + +BT::NodeStatus MLVicinityCondition::tick() +{ + if (promptAIModel()) { + { + return BT::NodeStatus::SUCCESS; + } + return BT::NodeStatus::FAILURE; +} + +bool MLVicinityCondition::promptAIModel() +{ + // Get OpenAI key from environment variable + const char * openai_key = std::getenv("OPENAI_API_KEY"); + RCLCPP_INFO_STREAM(node_->get_logger(), openai_key); + + const std::string base_url = "https://api.openai.com/v1/chat/completions"; + std::string response; + CURL * curl = curl_easy_init(); + + const std::string prompt = "Are there any obstacles in the image which would prevent forward " + "motion or possibly damage a robot as it moves forward? Note that a small doorsill is not a " + "significant obstacle. Please answer in one word, yes or no."; + + // Lambda to handle service response + auto write_callback = [&](void * contents, size_t size, size_t nmemb, std::string * response) + { + size_t total_size = size * nmemb; + response->append((char *)contents, total_size); + return total_size; + }; + + if (curl) { + // See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b + nlohmann::json request_data; + request_data["model"] = "gpt-4o"; + request_data["messages"][0]["role"] = "user"; + // Attach the latest image + std::string encoded_image = base64_encode(enc_msg, buf.size()); + { + const std::lock_guard lock(image_mutex_); + // If we have no image yet + if (!latest_image_) { + return false; + } + std::vector buf; + cv::imencode(".jpg", img, buf); + auto *enc_msg = reinterpret_cast(buf.data()); + encoded_image = base64_encode(enc_msg, buf.size()); + } + request_data["messages"][0]["content"] = {prompt, encoded_image}; + + std::string request_data_str = request_data.dump().c_str(); + + struct curl_slist * headers = NULL; + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, + ("Authorization: Bearer " + std::string(openai_key)).c_str()); + curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res)); + } + + curl_easy_cleanup(curl); + curl_slist_free_all(headers); + } + + nlohmann::json jresponse = nlohmann::json::parse(response); + + std::string string_response = jresponse["choices"][0]["message"]["content"].get(); + + // Parse for a one-word Yes/No reply + // Yes means there is an obstacle + if ((string_response.find("Yes") != std::string::npos) || + (string_response.find("yes") != std::string::npos)) + { + return false; + } + return true; +} + +} // namespace nav2_behavior_tree + +#include "behaviortree_cpp_v3/bt_factory.h" +BT_REGISTER_NODES(factory) +{ + factory.registerNodeType("MLVicinityCondition"); +}