Skip to content

Commit

Permalink
Add a BT condition to check robot's vicinity with an AI model (ChatGPT)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyZe committed Jun 25, 2024
1 parent 2cdfa5b commit 2e31b4e
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 2 deletions.
3 changes: 3 additions & 0 deletions nav2_behavior_tree/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ list(APPEND plugin_libs nav2_goal_updated_condition_bt_node)
add_library(nav2_is_path_valid_condition_bt_node SHARED plugins/condition/is_path_valid_condition.cpp)
list(APPEND plugin_libs nav2_is_path_valid_condition_bt_node)

add_library(nav2_ml_model_is_path_valid_condition_bt_node SHARED plugins/condition/ml_model_is_path_valid_condition.cpp)
list(APPEND plugin_libs nav2_ml_model_is_path_valid_condition_bt_node)

add_library(nav2_time_expired_condition_bt_node SHARED plugins/condition/time_expired_condition.cpp)
list(APPEND plugin_libs nav2_time_expired_condition_bt_node)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

#include "rclcpp/rclcpp.hpp"
#include "behaviortree_cpp/condition_node.h"
#include "geometry_msgs/msg/pose_stamped.hpp"
#include "nav2_msgs/srv/is_path_valid.hpp"

namespace nav2_behavior_tree
Expand Down Expand Up @@ -70,7 +69,7 @@ class IsPathValidCondition : public BT::ConditionNode
private:
rclcpp::Node::SharedPtr node_;
rclcpp::Client<nav2_msgs::srv::IsPathValid>::SharedPtr client_;
// The timeout value while waiting for a responce from the
// The timeout value while waiting for a response from the
// is path valid service
std::chrono::milliseconds server_timeout_;
bool initialized_;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) 2024 Andy Zelenak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_
#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_

#include <string>

#include "rclcpp/rclcpp.hpp"
#include "behaviortree_cpp_v3/condition_node.h"
#include "sensor_msgs/msg/image.hpp"

namespace nav2_behavior_tree
{

/**
* @brief A BT::ConditionNode that returns SUCCESS when a large language model says that a path is clear
* The input data has type sensor_msgs::msg::Image.
*/
class MLVicinityCondition : public BT::ConditionNode
{
public:
/**
* @brief A constructor for nav2_behavior_tree::MLVicinityCondition
* @param condition_name Name for the XML tag for this node
* @param conf BT node configuration
*/
MLVicinityCondition(
const std::string & condition_name,
const BT::NodeConfiguration & conf);

MLVicinityCondition() = delete;

/**
* @brief The main override required by a BT action
* @return BT::NodeStatus Status of tick execution
*/
BT::NodeStatus tick() override;

/**
* @brief Creates list of BT ports
* @return BT::PortsList Containing node-specific ports
*/
static BT::PortsList providedPorts()
{
return {
BT::InputPort<std::string>("image_topic", "Image topic which is subscribed to"),
BT::InputPort<std::chrono::milliseconds>("server_timeout")
};
}

private:
/**
* @brief Capture the latest image to send to the ML model
*
* @param msg
*/
void imageCallback(const sensor_msgs::msg::Image& msg)
{
const std::lock_guard<std::mutex> lock(image_mutex_);
latest_image_ = msg;
}

/**
* @brief Send the prompt (or otherwise run) the AI model. This includes the latest image as input.
*
* @return true if successful
*/
[[nodiscard]] bool promptAIModel();

rclcpp::Node::SharedPtr node_;
// The timeout value while waiting for a response from the
// is path valid service
std::chrono::milliseconds server_timeout_;
rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_;
mutable std::mutex image_mutex_;
std::optional<sensor_msgs::msg::Image> latest_image_;
};

} // namespace nav2_behavior_tree

#endif // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_
121 changes: 121 additions & 0 deletions nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) 2024 Andy Zelenak
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp"
#include <chrono>
#include <cstdlib>
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <string>

namespace nav2_behavior_tree
{

MLVicinityCondition::MLVicinityCondition(
const std::string & condition_name,
const BT::NodeConfiguration & conf)
: BT::ConditionNode(condition_name, conf)
{
node_ = config().blackboard->get<rclcpp::Node::SharedPtr>("node");

server_timeout_ = config().blackboard->template get<std::chrono::milliseconds>("server_timeout");
getInput<std::chrono::milliseconds>("server_timeout", server_timeout_);
std::string image_topic;
getInput<std::string>("image_topic", image_topic);
image_sub_ = node_->create_subscription<sensor_msgs::msg::Image>(
image_topic, 10, std::bind(&MLVicinityCondition::imageCallback, this, std::placeholders::_1));
}

BT::NodeStatus MLVicinityCondition::tick()
{
if (promptAIModel())
{
return BT::NodeStatus::SUCCESS;
}
return BT::NodeStatus::FAILURE;
}

bool MLVicinityCondition::promptAIModel()
{
// Get OpenAI key from environment variable
const char* openai_key = std::getenv("OPENAI_API_KEY");
RCLCPP_INFO_STREAM(node_->get_logger(), openai_key);

const std::string base_url = "https://api.openai.com/v1/chat/completions";
std::string response;
CURL* curl = curl_easy_init();

const std::string prompt = "Are there any obstacles in the image which would prevent forward motion or possibly "
"damage a robot as it moves forward? Note that a small doorsill is not a significant obstacle. Please answer in "
"one word, yes or no.";

// Lambda to handle service response
auto write_callback = [&](void* contents, size_t size, size_t nmemb, std::string* response)
{
size_t total_size = size * nmemb;
response->append((char*)contents, total_size);
return total_size;
};

if (curl)
{
// See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b
nlohmann::json request_data;
request_data["model"] = "gpt-4o";
request_data["messages"][0]["role"] = "user";
request_data["messages"][0]["content"] = prompt;
// TODO: attach the image to this image

std::string request_data_str = request_data.dump().c_str();

struct curl_slist* headers = NULL;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, ("Authorization: Bearer " + std::string(openai_key)).c_str());
curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
CURLcode res = curl_easy_perform(curl);

if (res != CURLE_OK)
{
RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res));
}

curl_easy_cleanup(curl);
curl_slist_free_all(headers);
}

nlohmann::json jresponse = nlohmann::json::parse(response);

std::string string_response = jresponse["choices"][0]["message"]["content"].get<std::string>();

// Parse for a one-word Yes/No reply
// Yes means the vicinity should be clear
if ((string_response.find("Yes") != std::string::npos) || (string_response.find("yes") != std::string::npos))
{
return true;
}
return false;
}

} // namespace nav2_behavior_tree

#include "behaviortree_cpp_v3/bt_factory.h"
BT_REGISTER_NODES(factory)
{
factory.registerNodeType<nav2_behavior_tree::MLVicinityCondition>("MLVicinityCondition");
}

0 comments on commit 2e31b4e

Please sign in to comment.