Skip to content

Commit

Permalink
Add a BT condition to check robot's vicinity with an AI model (ChatGPT)
Browse files Browse the repository at this point in the history
Signed-off-by: AndyZe <[email protected]>
  • Loading branch information
AndyZe committed Jun 25, 2024
1 parent 2cdfa5b commit 0373ecc
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 18 deletions.
37 changes: 21 additions & 16 deletions nav2_behavior_tree/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@ cmake_minimum_required(VERSION 3.5)
project(nav2_behavior_tree CXX)

find_package(ament_cmake REQUIRED)
find_package(behaviortree_cpp REQUIRED)
find_package(builtin_interfaces REQUIRED)
find_package(cv_bridge REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(nav_msgs REQUIRED)
find_package(nav2_common REQUIRED)
find_package(nav2_msgs REQUIRED)
find_package(nav2_util REQUIRED)
find_package(rclcpp REQUIRED)
find_package(rclcpp_action REQUIRED)
find_package(rclcpp_lifecycle REQUIRED)
find_package(builtin_interfaces REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(sensor_msgs REQUIRED)
find_package(nav2_msgs REQUIRED)
find_package(nav_msgs REQUIRED)
find_package(behaviortree_cpp REQUIRED)
find_package(tf2_ros REQUIRED)
find_package(tf2_geometry_msgs REQUIRED)
find_package(std_msgs REQUIRED)
find_package(std_srvs REQUIRED)
find_package(nav2_util REQUIRED)
find_package(tf2_geometry_msgs REQUIRED)
find_package(tf2_ros REQUIRED)

nav2_package()

Expand All @@ -29,20 +30,21 @@ include_directories(
set(library_name ${PROJECT_NAME})

set(dependencies
behaviortree_cpp
cv_bridge
geometry_msgs
nav_msgs
nav2_msgs
nav2_util
rclcpp
rclcpp_action
rclcpp_lifecycle
geometry_msgs
sensor_msgs
nav2_msgs
nav_msgs
behaviortree_cpp
tf2
tf2_ros
tf2_geometry_msgs
std_msgs
std_srvs
nav2_util
tf2
tf2_geometry_msgs
tf2_ros
)

add_library(${library_name} SHARED
Expand Down Expand Up @@ -119,6 +121,9 @@ list(APPEND plugin_libs nav2_goal_updated_condition_bt_node)
# Each condition plugin below is built as its own shared library and appended to
# plugin_libs.
add_library(nav2_is_path_valid_condition_bt_node SHARED plugins/condition/is_path_valid_condition.cpp)
list(APPEND plugin_libs nav2_is_path_valid_condition_bt_node)

# NOTE(review): plugins/condition/ml_vicinity_condition.cpp includes <curl/curl.h> and
# <nlohmann/json.hpp>; confirm this target is linked against libcurl and nlohmann_json
# somewhere in this file, since neither appears in the visible find_package() calls.
add_library(nav2_ml_vicinity_condition_bt_node SHARED plugins/condition/ml_vicinity_condition.cpp)
list(APPEND plugin_libs nav2_ml_vicinity_condition_bt_node)

add_library(nav2_time_expired_condition_bt_node SHARED plugins/condition/time_expired_condition.cpp)
list(APPEND plugin_libs nav2_time_expired_condition_bt_node)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

#include "rclcpp/rclcpp.hpp"
#include "behaviortree_cpp/condition_node.h"
#include "geometry_msgs/msg/pose_stamped.hpp"
#include "nav2_msgs/srv/is_path_valid.hpp"

namespace nav2_behavior_tree
Expand Down Expand Up @@ -70,7 +69,7 @@ class IsPathValidCondition : public BT::ConditionNode
private:
rclcpp::Node::SharedPtr node_;
rclcpp::Client<nav2_msgs::srv::IsPathValid>::SharedPtr client_;
// The timeout value while waiting for a responce from the
// The timeout value while waiting for a response from the
// is path valid service
std::chrono::milliseconds server_timeout_;
bool initialized_;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) 2024 Andy Zelenak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_
#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_

#include <chrono>
#include <mutex>
#include <optional>
#include <string>

#include "behaviortree_cpp/condition_node.h"
#include "cv_bridge/cv_bridge.h"
#include "rclcpp/rclcpp.hpp"
#include "sensor_msgs/image_encodings.hpp"
#include "sensor_msgs/msg/image.hpp"

namespace nav2_behavior_tree
{

/**
 * @brief A BT::ConditionNode that returns SUCCESS when an AI model, prompted with the
 * most recently received image, reports that the robot's vicinity is clear of obstacles.
 * The input data has type sensor_msgs::msg::Image and is received on the topic named by
 * the "image_topic" input port.
 */
class MLVicinityCondition : public BT::ConditionNode
{
public:
  /**
   * @brief A constructor for nav2_behavior_tree::MLVicinityCondition
   * @param condition_name Name for the XML tag for this node
   * @param conf BT node configuration
   */
  MLVicinityCondition(
    const std::string & condition_name,
    const BT::NodeConfiguration & conf);

  MLVicinityCondition() = delete;

  /**
   * @brief The main override required by a BT condition
   * @return BT::NodeStatus SUCCESS when promptAIModel() returns true, FAILURE otherwise
   */
  BT::NodeStatus tick() override;

  /**
   * @brief Creates list of BT ports
   * @return BT::PortsList Containing node-specific ports
   */
  static BT::PortsList providedPorts()
  {
    return {
      BT::InputPort<std::string>("image_topic", "Image topic which is subscribed to"),
      BT::InputPort<std::chrono::milliseconds>("server_timeout")
    };
  }

private:
  /**
   * @brief Capture the latest image to send to the ML model
   *
   * @param msg Image received on the subscribed topic
   */
  void imageCallback(const sensor_msgs::msg::Image & msg)
  {
    // toCvCopy is used because the toCvShare overload taking a message reference
    // additionally requires a tracked-object pointer that is not available here.
    // NOTE(review): TYPE_32FC1 is a single-channel float encoding; confirm it matches
    // what is published on image_topic, since cv_bridge throws on unsupported conversions.
    const cv_bridge::CvImageConstPtr converted =
      cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::TYPE_32FC1);

    // Hold the mutex only while swapping the stored pointer.
    const std::lock_guard<std::mutex> lock(image_mutex_);
    latest_image_ = converted;
  }

  /**
   * @brief Send the prompt (or otherwise run) the AI model. This includes the latest image as input.
   *
   * @return true if successful
   */
  [[nodiscard]] bool promptAIModel();

  // ROS node taken from the "node" entry of the behavior tree blackboard
  rclcpp::Node::SharedPtr node_;
  // Timeout read from the "server_timeout" blackboard entry and overridable through the
  // input port of the same name; intended for waiting on a response from the AI model
  std::chrono::milliseconds server_timeout_;
  // Subscription that feeds imageCallback()
  rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_;
  // Guards latest_image_, which is written by imageCallback() and read by promptAIModel()
  mutable std::mutex image_mutex_;
  // Most recently converted image; empty until the first message arrives
  std::optional<cv_bridge::CvImageConstPtr> latest_image_;
};

} // namespace nav2_behavior_tree

#endif  // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_
135 changes: 135 additions & 0 deletions nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (c) 2024 Andy Zelenak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp"
#include <chrono>
#include <cstdlib>
#include <cv_bridge/cv_bridge.h>
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <string>

namespace nav2_behavior_tree
{

// Builds the condition node: fetches the ROS node and the default timeout from the
// blackboard, lets the "server_timeout" port override that default, and subscribes to
// the topic given by the "image_topic" port.
MLVicinityCondition::MLVicinityCondition(
  const std::string & condition_name,
  const BT::NodeConfiguration & conf)
: BT::ConditionNode(condition_name, conf),
  node_(config().blackboard->get<rclcpp::Node::SharedPtr>("node")),
  server_timeout_(config().blackboard->get<std::chrono::milliseconds>("server_timeout"))
{
  // A value supplied on the input port replaces the blackboard-wide timeout.
  getInput<std::chrono::milliseconds>("server_timeout", server_timeout_);

  std::string topic_name;
  getInput<std::string>("image_topic", topic_name);

  // Forward every received image to the member callback.
  auto on_image = [this](const sensor_msgs::msg::Image & image) {
      imageCallback(image);
    };
  image_sub_ = node_->create_subscription<sensor_msgs::msg::Image>(topic_name, 10, on_image);
}

// Queries the AI model once per tick: SUCCESS when promptAIModel() returns true,
// FAILURE otherwise.
BT::NodeStatus MLVicinityCondition::tick()
{
  return promptAIModel() ? BT::NodeStatus::SUCCESS : BT::NodeStatus::FAILURE;
}

bool MLVicinityCondition::promptAIModel()
{
// Get OpenAI key from environment variable
const char * openai_key = std::getenv("OPENAI_API_KEY");
RCLCPP_INFO_STREAM(node_->get_logger(), openai_key);

const std::string base_url = "https://api.openai.com/v1/chat/completions";
std::string response;
CURL * curl = curl_easy_init();

const std::string prompt = "Are there any obstacles in the image which would prevent forward "
"motion or possibly damage a robot as it moves forward? Note that a small doorsill is not a "
"significant obstacle. Please answer in one word, yes or no.";

// Lambda to handle service response
auto write_callback = [&](void * contents, size_t size, size_t nmemb, std::string * response)
{
size_t total_size = size * nmemb;
response->append((char *)contents, total_size);
return total_size;
};

if (curl) {
// See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b
nlohmann::json request_data;
request_data["model"] = "gpt-4o";
request_data["messages"][0]["role"] = "user";
// Attach the latest image
std::string encoded_image = base64_encode(enc_msg, buf.size());
{
const std::lock_guard<std::mutex> lock(image_mutex_);
// If we have no image yet
if (!latest_image_) {
return false;
}
std::vector<uchar> buf;
cv::imencode(".jpg", img, buf);
auto *enc_msg = reinterpret_cast<unsigned char*>(buf.data());
encoded_image = base64_encode(enc_msg, buf.size());
}
request_data["messages"][0]["content"] = {prompt, encoded_image};

std::string request_data_str = request_data.dump().c_str();

struct curl_slist * headers = NULL;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers,
("Authorization: Bearer " + std::string(openai_key)).c_str());
curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
CURLcode res = curl_easy_perform(curl);

if (res != CURLE_OK) {
RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res));
}

curl_easy_cleanup(curl);
curl_slist_free_all(headers);
}

nlohmann::json jresponse = nlohmann::json::parse(response);

std::string string_response = jresponse["choices"][0]["message"]["content"].get<std::string>();

// Parse for a one-word Yes/No reply
// Yes means there is an obstacle
if ((string_response.find("Yes") != std::string::npos) ||
(string_response.find("yes") != std::string::npos))
{
return false;
}
return true;
}

} // namespace nav2_behavior_tree

// Use the same BehaviorTree.CPP package as the node's header ("behaviortree_cpp/...").
#include "behaviortree_cpp/bt_factory.h"
// Plugin entry point: registers this condition under the XML tag "MLVicinityCondition".
BT_REGISTER_NODES(factory)
{
  factory.registerNodeType<nav2_behavior_tree::MLVicinityCondition>("MLVicinityCondition");
}

0 comments on commit 0373ecc

Please sign in to comment.