Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a BT condition to check robot's vicinity with an AI model (currently ChatGPT) #4486

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions nav2_behavior_tree/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@ cmake_minimum_required(VERSION 3.5)
project(nav2_behavior_tree CXX)

find_package(ament_cmake REQUIRED)
find_package(behaviortree_cpp REQUIRED)
find_package(builtin_interfaces REQUIRED)
find_package(cv_bridge REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(nav_msgs REQUIRED)
find_package(nav2_common REQUIRED)
find_package(nav2_msgs REQUIRED)
find_package(nav2_util REQUIRED)
find_package(rclcpp REQUIRED)
find_package(rclcpp_action REQUIRED)
find_package(rclcpp_lifecycle REQUIRED)
find_package(builtin_interfaces REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(sensor_msgs REQUIRED)
find_package(nav2_msgs REQUIRED)
find_package(nav_msgs REQUIRED)
find_package(behaviortree_cpp REQUIRED)
find_package(tf2_ros REQUIRED)
find_package(tf2_geometry_msgs REQUIRED)
find_package(std_msgs REQUIRED)
find_package(std_srvs REQUIRED)
find_package(nav2_util REQUIRED)
find_package(tf2_geometry_msgs REQUIRED)
find_package(tf2_ros REQUIRED)

nav2_package()

Expand All @@ -29,20 +30,21 @@ include_directories(
set(library_name ${PROJECT_NAME})

set(dependencies
behaviortree_cpp
cv_bridge
geometry_msgs
nav_msgs
nav2_msgs
nav2_util
rclcpp
rclcpp_action
rclcpp_lifecycle
geometry_msgs
sensor_msgs
nav2_msgs
nav_msgs
behaviortree_cpp
tf2
tf2_ros
tf2_geometry_msgs
std_msgs
std_srvs
nav2_util
tf2
tf2_geometry_msgs
tf2_ros
)

add_library(${library_name} SHARED
Expand Down Expand Up @@ -119,6 +121,9 @@ list(APPEND plugin_libs nav2_goal_updated_condition_bt_node)
add_library(nav2_is_path_valid_condition_bt_node SHARED plugins/condition/is_path_valid_condition.cpp)
list(APPEND plugin_libs nav2_is_path_valid_condition_bt_node)

add_library(nav2_ml_vicinity_condition_bt_node SHARED plugins/condition/ml_vicinity_condition.cpp)
list(APPEND plugin_libs nav2_ml_vicinity_condition_bt_node)

add_library(nav2_time_expired_condition_bt_node SHARED plugins/condition/time_expired_condition.cpp)
list(APPEND plugin_libs nav2_time_expired_condition_bt_node)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

#include "rclcpp/rclcpp.hpp"
#include "behaviortree_cpp/condition_node.h"
#include "geometry_msgs/msg/pose_stamped.hpp"
#include "nav2_msgs/srv/is_path_valid.hpp"

namespace nav2_behavior_tree
Expand Down Expand Up @@ -70,7 +69,7 @@ class IsPathValidCondition : public BT::ConditionNode
private:
rclcpp::Node::SharedPtr node_;
rclcpp::Client<nav2_msgs::srv::IsPathValid>::SharedPtr client_;
// The timeout value while waiting for a responce from the
// The timeout value while waiting for a response from the
// is path valid service
std::chrono::milliseconds server_timeout_;
bool initialized_;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) 2024 Andy Zelenak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_
#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_

#include <chrono>
#include <mutex>
#include <optional>
#include <string>
#include <utility>

#include "behaviortree_cpp/condition_node.h"
#include "cv_bridge/cv_bridge.hpp"
#include "rclcpp/rclcpp.hpp"
#include "sensor_msgs/image_encodings.hpp"
#include "sensor_msgs/msg/image.hpp"

namespace nav2_behavior_tree
{

/**
 * @brief A BT::ConditionNode that returns SUCCESS when an AI model says the
 * robot's vicinity is clear, and FAILURE otherwise.
 *
 * The node subscribes to a sensor_msgs::msg::Image topic and keeps the most
 * recent frame; that frame is what promptAIModel() sends to the model.
 */
class MLVicinityCondition : public BT::ConditionNode
{
public:
  /**
   * @brief A constructor for nav2_behavior_tree::MLVicinityCondition
   * @param condition_name Name for the XML tag for this node
   * @param conf BT node configuration
   */
  MLVicinityCondition(
    const std::string & condition_name,
    const BT::NodeConfiguration & conf);

  MLVicinityCondition() = delete;

  /**
   * @brief The main override required by a BT condition
   * @return BT::NodeStatus SUCCESS if promptAIModel() returns true, FAILURE otherwise
   */
  BT::NodeStatus tick() override;

  /**
   * @brief Creates list of BT ports
   * @return BT::PortsList Containing node-specific ports
   */
  static BT::PortsList providedPorts()
  {
    return {
      BT::InputPort<std::string>("image_topic", "Image topic which is subscribed to"),
      BT::InputPort<std::chrono::milliseconds>(
        "server_timeout", "Timeout for the request to the AI model")
    };
  }

private:
  /**
   * @brief Capture the latest image to send to the ML model
   *
   * The message is deep-copied into an 8-bit BGR cv::Mat. cv_bridge::toCvCopy is
   * used because the subscription delivers a const reference, which carries no
   * ownership that cv_bridge::toCvShare could share. BGR8 is requested because
   * the image is later JPEG-encoded, which needs 8-bit data in OpenCV's channel
   * order.
   * NOTE(review): BGR8 assumes a colour (or mono) camera topic - confirm that no
   * depth-image input is intended.
   *
   * @param msg Image received on the subscribed topic
   */
  void imageCallback(const sensor_msgs::msg::Image & msg)
  {
    cv_bridge::CvImageConstPtr converted;
    try {
      converted = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
    } catch (const cv_bridge::Exception & e) {
      // Never let a conversion failure escape the subscription callback.
      RCLCPP_ERROR_THROTTLE(
        node_->get_logger(), *node_->get_clock(), 5000,
        "Could not convert image to BGR8: %s", e.what());
      return;
    }

    // Convert outside the lock; only the pointer swap is guarded.
    const std::lock_guard<std::mutex> lock(image_mutex_);
    latest_image_ = std::move(converted);
  }

  /**
   * @brief Send the prompt to (or otherwise run) the AI model. This includes the latest image as input.
   *
   * @return true if successful
   */
  [[nodiscard]] bool promptAIModel();

  rclcpp::Node::SharedPtr node_;
  // The timeout value while waiting for a response from the AI model
  std::chrono::milliseconds server_timeout_;
  rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_;
  // Guards latest_image_, which is written by imageCallback()
  mutable std::mutex image_mutex_;
  // Most recent image received; empty until the first message arrives
  std::optional<cv_bridge::CvImageConstPtr> latest_image_;
};

}  // namespace nav2_behavior_tree

#endif  // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_
44 changes: 23 additions & 21 deletions nav2_behavior_tree/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,41 @@
<build_export_depend>tf2_geometry_msgs</build_export_depend>
<build_export_depend>std_srvs</build_export_depend>

<build_depend>rclcpp</build_depend>
<build_depend>rclcpp_action</build_depend>
<build_depend>rclcpp_lifecycle</build_depend>
<build_depend>behaviortree_cpp</build_depend>
<build_depend>builtin_interfaces</build_depend>
<build_depend>cv_bridge</build_depend>
<build_depend>geometry_msgs</build_depend>
<build_depend>sensor_msgs</build_depend>
<build_depend>nav2_msgs</build_depend>
<build_depend>lifecycle_msgs</build_depend>
<build_depend>nav_msgs</build_depend>
<build_depend>tf2</build_depend>
<build_depend>tf2_ros</build_depend>
<build_depend>tf2_geometry_msgs</build_depend>
<build_depend>nav2_common</build_depend>
<build_depend>nav2_msgs</build_depend>
<build_depend>nav2_util</build_depend>
<build_depend>rclcpp_action</build_depend>
<build_depend>rclcpp_lifecycle</build_depend>
<build_depend>rclcpp</build_depend>
<build_depend>sensor_msgs</build_depend>
<build_depend>std_msgs</build_depend>
<build_depend>std_srvs</build_depend>
<build_depend>nav2_util</build_depend>
<build_depend>lifecycle_msgs</build_depend>
<build_depend>nav2_common</build_depend>
<build_depend>tf2_geometry_msgs</build_depend>
<build_depend>tf2_ros</build_depend>
<build_depend>tf2</build_depend>

<exec_depend>rclcpp</exec_depend>
<exec_depend>rclcpp_action</exec_depend>
<exec_depend>rclcpp_lifecycle</exec_depend>
<exec_depend>std_msgs</exec_depend>
<exec_depend>behaviortree_cpp</exec_depend>
<exec_depend>builtin_interfaces</exec_depend>
<exec_depend>cv_bridge</exec_depend>
<exec_depend>geometry_msgs</exec_depend>
<exec_depend>sensor_msgs</exec_depend>
<exec_depend>nav2_msgs</exec_depend>
<exec_depend>lifecycle_msgs</exec_depend>
<exec_depend>nav_msgs</exec_depend>
<exec_depend>tf2</exec_depend>
<exec_depend>tf2_ros</exec_depend>
<exec_depend>tf2_geometry_msgs</exec_depend>
<exec_depend>nav2_msgs</exec_depend>
<exec_depend>nav2_util</exec_depend>
<exec_depend>lifecycle_msgs</exec_depend>
<exec_depend>rclcpp_action</exec_depend>
<exec_depend>rclcpp_lifecycle</exec_depend>
<exec_depend>rclcpp</exec_depend>
<exec_depend>sensor_msgs</exec_depend>
<exec_depend>std_msgs</exec_depend>
<exec_depend>tf2_geometry_msgs</exec_depend>
<exec_depend>tf2_ros</exec_depend>
<exec_depend>tf2</exec_depend>

<test_depend>ament_lint_common</test_depend>
<test_depend>ament_lint_auto</test_depend>
Expand Down
133 changes: 133 additions & 0 deletions nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) 2024 Andy Zelenak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp"

#include <curl/curl.h>

#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include <cv_bridge/cv_bridge.hpp>
#include <nlohmann/json.hpp>
#include <opencv2/imgcodecs.hpp>

namespace nav2_behavior_tree
{

MLVicinityCondition::MLVicinityCondition(
  const std::string & condition_name,
  const BT::NodeConfiguration & conf)
: BT::ConditionNode(condition_name, conf)
{
  // Shared state published on the blackboard: the ROS node and the default timeout.
  const auto blackboard = config().blackboard;
  node_ = blackboard->get<rclcpp::Node::SharedPtr>("node");
  server_timeout_ = blackboard->get<std::chrono::milliseconds>("server_timeout");

  // A "server_timeout" input port, when provided, overrides the blackboard default.
  getInput<std::chrono::milliseconds>("server_timeout", server_timeout_);

  // Subscribe to the configured image topic with a queue depth of 10.
  std::string topic_name;
  getInput<std::string>("image_topic", topic_name);
  image_sub_ = node_->create_subscription<sensor_msgs::msg::Image>(
    topic_name, 10,
    [this](const sensor_msgs::msg::Image & msg) {imageCallback(msg);});
}

// Maps the AI model's verdict onto a BT status: true -> SUCCESS, false -> FAILURE.
BT::NodeStatus MLVicinityCondition::tick()
{
  const bool model_ok = promptAIModel();
  return model_ok ? BT::NodeStatus::SUCCESS : BT::NodeStatus::FAILURE;
}

bool MLVicinityCondition::promptAIModel()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i figured i'd cram everything related to ChatGPT into this one function so it can be swapped out easily, although it's not the most efficient way

{
// Get OpenAI key from environment variable
const char * openai_key = std::getenv("OPENAI_API_KEY");
RCLCPP_INFO_STREAM(node_->get_logger(), openai_key);

const std::string base_url = "https://api.openai.com/v1/chat/completions";
std::string response;
CURL * curl = curl_easy_init();

const std::string prompt = "Are there any obstacles in the image which would prevent forward "
"motion or possibly damage a robot as it moves forward? Note that a small doorsill is not a "
"significant obstacle. Please answer in one word, yes or no.";

// Lambda to handle service response
auto write_callback = [&](void * contents, size_t size, size_t nmemb, std::string * response)
{
size_t total_size = size * nmemb;
response->append(reinterpret_cast<char *>(contents), total_size);
return total_size;
};

if (curl) {
// See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b
nlohmann::json request_data;
request_data["model"] = "gpt-4o";
request_data["messages"][0]["role"] = "user";
// Attach the latest image
std::string encoded_image = base64_encode(enc_msg, buf.size());
{
const std::lock_guard<std::mutex> lock(image_mutex_);
// If we have no image yet
if (!latest_image_) {
return false;
}
std::vector<uchar> buf;
cv::imencode(".jpg", img, buf);
auto *enc_msg = reinterpret_cast<unsigned char*>(buf.data());
encoded_image = base64_encode(enc_msg, buf.size());
}
request_data["messages"][0]["content"] = {prompt, encoded_image};

std::string request_data_str = request_data.dump().c_str();

struct curl_slist * headers = NULL;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers,
("Authorization: Bearer " + std::string(openai_key)).c_str());
curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
CURLcode res = curl_easy_perform(curl);

if (res != CURLE_OK) {
RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res));
}

curl_easy_cleanup(curl);
curl_slist_free_all(headers);
}

nlohmann::json jresponse = nlohmann::json::parse(response);

std::string string_response = jresponse["choices"][0]["message"]["content"].get<std::string>();

// Parse for a one-word Yes/No reply
// Yes means there is an obstacle
if ((string_response.find("Yes") != std::string::npos) ||
(string_response.find("yes") != std::string::npos))
{
return false;
}
return true;
}
} // namespace nav2_behavior_tree

#include "behaviortree_cpp/bt_factory.h"
// Plugin entry point: registers the condition with the BehaviorTree.CPP factory
// under the XML tag "MLVicinityCondition".
BT_REGISTER_NODES(factory)
{
  using nav2_behavior_tree::MLVicinityCondition;
  factory.registerNodeType<MLVicinityCondition>("MLVicinityCondition");
}
Loading