From 7d793f4a02f42496cb23a6f8e18d623ae724e72a Mon Sep 17 00:00:00 2001 From: AndyZe Date: Mon, 24 Jun 2024 22:16:45 -0500 Subject: [PATCH] Add a BT condition to check robot's vicinity with an AI model (ChatGPT) Signed-off-by: AndyZe --- nav2_behavior_tree/CMakeLists.txt | 3 + .../condition/is_path_valid_condition.hpp | 3 +- .../condition/ml_vicinity_condition.hpp | 93 ++++++++++++++ .../condition/ml_vicinity_condition.cpp | 121 ++++++++++++++++++ 4 files changed, 218 insertions(+), 2 deletions(-) create mode 100644 nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp create mode 100644 nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp diff --git a/nav2_behavior_tree/CMakeLists.txt b/nav2_behavior_tree/CMakeLists.txt index 6d43f4e760e..cadb47a94f5 100644 --- a/nav2_behavior_tree/CMakeLists.txt +++ b/nav2_behavior_tree/CMakeLists.txt @@ -119,6 +119,9 @@ list(APPEND plugin_libs nav2_goal_updated_condition_bt_node) add_library(nav2_is_path_valid_condition_bt_node SHARED plugins/condition/is_path_valid_condition.cpp) list(APPEND plugin_libs nav2_is_path_valid_condition_bt_node) +add_library(nav2_ml_vicinity_condition_bt_node SHARED plugins/condition/ml_vicinity_condition.cpp) +list(APPEND plugin_libs nav2_ml_vicinity_condition_bt_node) + add_library(nav2_time_expired_condition_bt_node SHARED plugins/condition/time_expired_condition.cpp) list(APPEND plugin_libs nav2_time_expired_condition_bt_node) diff --git a/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp index 5a9f255a9b7..e0761863e97 100644 --- a/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp +++ b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/is_path_valid_condition.hpp @@ -20,7 +20,6 @@ #include "rclcpp/rclcpp.hpp" #include "behaviortree_cpp/condition_node.h" -#include "geometry_msgs/msg/pose_stamped.hpp" #include "nav2_msgs/srv/is_path_valid.hpp" namespace nav2_behavior_tree @@ -70,7 +69,7 @@ class IsPathValidCondition : public BT::ConditionNode private: rclcpp::Node::SharedPtr node_; rclcpp::Client::SharedPtr client_; - // The timeout value while waiting for a responce from the + // The timeout value while waiting for a response from the // is path valid service std::chrono::milliseconds server_timeout_; bool initialized_; diff --git a/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp new file mode 100644 index 00000000000..0dd13627968 --- /dev/null +++ b/nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp @@ -0,0 +1,93 @@ +// Copyright (c) 2024 Andy Zelenak +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_ +#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_ + +#include + +#include "rclcpp/rclcpp.hpp" +#include "behaviortree_cpp_v3/condition_node.h" +#include "sensor_msgs/msg/image.hpp" + +namespace nav2_behavior_tree +{ + +/** + * @brief A BT::ConditionNode that returns SUCCESS when a large language model says that a path is clear + * The input data has type sensor_msgs::msg::Image. + */ +class MLVicinityCondition : public BT::ConditionNode +{ +public: + /** + * @brief A constructor for nav2_behavior_tree::MLVicinityCondition + * @param condition_name Name for the XML tag for this node + * @param conf BT node configuration + */ + MLVicinityCondition( + const std::string & condition_name, + const BT::NodeConfiguration & conf); + + MLVicinityCondition() = delete; + + /** + * @brief The main override required by a BT action + * @return BT::NodeStatus Status of tick execution + */ + BT::NodeStatus tick() override; + + /** + * @brief Creates list of BT ports + * @return BT::PortsList Containing node-specific ports + */ + static BT::PortsList providedPorts() + { + return { + BT::InputPort("image_topic", "Image topic which is subscribed to"), + BT::InputPort("server_timeout") + }; + } + +private: + /** + * @brief Capture the latest image to send to the ML model + * + * @param msg + */ + void imageCallback(const sensor_msgs::msg::Image& msg) + { + const std::lock_guard lock(image_mutex_); + latest_image_ = msg; + } + + /** + * @brief Send the prompt (or otherwise run) the AI model. This includes the latest image as input. + * + * @return true if successful + */ + [[nodiscard]] bool promptAIModel(); + + rclcpp::Node::SharedPtr node_; + // The timeout value while waiting for a response from the + // is path valid service + std::chrono::milliseconds server_timeout_; + rclcpp::Subscription::SharedPtr image_sub_; + mutable std::mutex image_mutex_; + std::optional latest_image_; +}; + +} // namespace nav2_behavior_tree + +#endif // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_ diff --git a/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp b/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp new file mode 100644 index 00000000000..b31fada9279 --- /dev/null +++ b/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp @@ -0,0 +1,121 @@ +// Copyright (c) 2024 Andy Zelenak +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp" +#include +#include +#include +#include +#include + +namespace nav2_behavior_tree +{ + +MLVicinityCondition::MLVicinityCondition( + const std::string & condition_name, + const BT::NodeConfiguration & conf) +: BT::ConditionNode(condition_name, conf) +{ + node_ = config().blackboard->get("node"); + + server_timeout_ = config().blackboard->template get("server_timeout"); + getInput("server_timeout", server_timeout_); + std::string image_topic; + getInput("image_topic", image_topic); + image_sub_ = node_->create_subscription( + image_topic, 10, std::bind(&MLVicinityCondition::imageCallback, this, std::placeholders::_1)); +} + +BT::NodeStatus MLVicinityCondition::tick() +{ + if (promptAIModel()) + { + return BT::NodeStatus::SUCCESS; + } + return BT::NodeStatus::FAILURE; +} + +bool MLVicinityCondition::promptAIModel() +{ + // Get OpenAI key from environment variable + const char* openai_key = std::getenv("OPENAI_API_KEY"); + RCLCPP_INFO_STREAM(node_->get_logger(), openai_key); + + const std::string base_url = "https://api.openai.com/v1/chat/completions"; + std::string response; + CURL* curl = curl_easy_init(); + + const std::string prompt = "Are there any obstacles in the image which would prevent forward motion or possibly " + "damage a robot as it moves forward? Note that a small doorsill is not a significant obstacle. Please answer in " + "one word, yes or no."; + + // Lambda to handle service response + auto write_callback = [&](void* contents, size_t size, size_t nmemb, std::string* response) + { + size_t total_size = size * nmemb; + response->append((char*)contents, total_size); + return total_size; + }; + + if (curl) + { + // See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b + nlohmann::json request_data; + request_data["model"] = "gpt-4o"; + request_data["messages"][0]["role"] = "user"; + request_data["messages"][0]["content"] = prompt; + // TODO: attach the image to this image + + std::string request_data_str = request_data.dump().c_str(); + + struct curl_slist* headers = NULL; + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, ("Authorization: Bearer " + std::string(openai_key)).c_str()); + curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) + { + RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res)); + } + + curl_easy_cleanup(curl); + curl_slist_free_all(headers); + } + + nlohmann::json jresponse = nlohmann::json::parse(response); + + std::string string_response = jresponse["choices"][0]["message"]["content"].get(); + + // Parse for a one-word Yes/No reply + // Yes means the vicinity should be clear + if ((string_response.find("Yes") != std::string::npos) || (string_response.find("yes") != std::string::npos)) + { + return true; + } + return false; +} + +} // namespace nav2_behavior_tree + +#include "behaviortree_cpp_v3/bt_factory.h" +BT_REGISTER_NODES(factory) +{ + factory.registerNodeType("MLVicinityCondition"); +}