-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a BT condition to check robot's vicinity with an AI model (ChatGPT)
Signed-off-by: AndyZe <[email protected]>
- Loading branch information
Showing
4 changed files
with
244 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Copyright (c) 2024 Andy Zelenak | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ | ||
#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ | ||
|
||
#include "behaviortree_cpp_v3/condition_node.h" | ||
#include "rclcpp/rclcpp.hpp" | ||
#include "sensor_msgs/msg/image.hpp" | ||
#include <string> | ||
|
||
namespace nav2_behavior_tree | ||
{ | ||
|
||
/** | ||
* @brief A BT::ConditionNode, returns SUCCESS when a large language model says the vicinity is clear | ||
* The input data has type sensor_msgs::msg::Image. | ||
*/ | ||
class MLVicinityCondition : public BT::ConditionNode | ||
{ | ||
public: | ||
/** | ||
* @brief A constructor for nav2_behavior_tree::MLVicinityCondition | ||
* @param condition_name Name for the XML tag for this node | ||
* @param conf BT node configuration | ||
*/ | ||
MLVicinityCondition( | ||
const std::string & condition_name, | ||
const BT::NodeConfiguration & conf); | ||
|
||
MLVicinityCondition() = delete; | ||
|
||
/** | ||
* @brief The main override required by a BT action | ||
* @return BT::NodeStatus Status of tick execution | ||
*/ | ||
BT::NodeStatus tick() override; | ||
|
||
/** | ||
* @brief Creates list of BT ports | ||
* @return BT::PortsList Containing node-specific ports | ||
*/ | ||
static BT::PortsList providedPorts() | ||
{ | ||
return { | ||
BT::InputPort<std::string>("image_topic", "Image topic which is subscribed to"), | ||
BT::InputPort<std::chrono::milliseconds>("server_timeout") | ||
}; | ||
} | ||
|
||
private: | ||
/** | ||
* @brief Capture the latest image to send to the ML model | ||
* | ||
* @param msg | ||
*/ | ||
void imageCallback(const sensor_msgs::msg::Image& msg) | ||
{ | ||
const std::lock_guard<std::mutex> lock(image_mutex_); | ||
latest_image_ = cv_bridge::toCvShare(msg, sensor_msgs::image_encodings::TYPE_32FC1); | ||
} | ||
|
||
/** | ||
* @brief Send the prompt (or otherwise run) the AI model. This includes the latest image as input. | ||
* | ||
* @return true if successful | ||
*/ | ||
[[nodiscard]] bool promptAIModel(); | ||
|
||
rclcpp::Node::SharedPtr node_; | ||
// The timeout value while waiting for a response from the | ||
// is path valid service | ||
std::chrono::milliseconds server_timeout_; | ||
rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_; | ||
mutable std::mutex image_mutex_; | ||
std::optional<cv_bridge::CvImageConstPtr> latest_image_; | ||
}; | ||
|
||
} // namespace nav2_behavior_tree | ||
|
||
#endif // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_ |
130 changes: 130 additions & 0 deletions
130
nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
// Copyright (c) 2024 Andy Zelenak | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp" | ||
#include <chrono> | ||
#include <cstdlib> | ||
#include <cv_bridge/cv_bridge.h> | ||
#include <curl/curl.h> | ||
#include <nlohmann/json.hpp> | ||
#include <string> | ||
|
||
namespace nav2_behavior_tree | ||
{ | ||
|
||
MLVicinityCondition::MLVicinityCondition( | ||
const std::string & condition_name, | ||
const BT::NodeConfiguration & conf) | ||
: BT::ConditionNode(condition_name, conf) | ||
{ | ||
node_ = config().blackboard->get<rclcpp::Node::SharedPtr>("node"); | ||
|
||
server_timeout_ = config().blackboard->template get<std::chrono::milliseconds>("server_timeout"); | ||
getInput<std::chrono::milliseconds>("server_timeout", server_timeout_); | ||
std::string image_topic; | ||
getInput<std::string>("image_topic", image_topic); | ||
image_sub_ = node_->create_subscription<sensor_msgs::msg::Image>( | ||
image_topic, 10, std::bind(&MLVicinityCondition::imageCallback, | ||
this, std::placeholders::_1)); | ||
} | ||
|
||
BT::NodeStatus MLVicinityCondition::tick() | ||
{ | ||
if (promptAIModel()) { | ||
{ | ||
return BT::NodeStatus::SUCCESS; | ||
} | ||
return BT::NodeStatus::FAILURE; | ||
} | ||
|
||
bool MLVicinityCondition::promptAIModel() | ||
{ | ||
// Get OpenAI key from environment variable | ||
const char * openai_key = std::getenv("OPENAI_API_KEY"); | ||
RCLCPP_INFO_STREAM(node_->get_logger(), openai_key); | ||
|
||
const std::string base_url = "https://api.openai.com/v1/chat/completions"; | ||
std::string response; | ||
CURL * curl = curl_easy_init(); | ||
|
||
const std::string prompt = "Are there any obstacles in the image which would prevent forward " | ||
"motion or possibly damage a robot as it moves forward? Note that a small doorsill is not a " | ||
"significant obstacle. Please answer in one word, yes or no."; | ||
|
||
// Lambda to handle service response | ||
auto write_callback = [&](void * contents, size_t size, size_t nmemb, std::string * response) | ||
{ | ||
size_t total_size = size * nmemb; | ||
response->append((char *)contents, total_size); | ||
return total_size; | ||
}; | ||
|
||
if (curl) { | ||
// See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b | ||
nlohmann::json request_data; | ||
request_data["model"] = "gpt-4o"; | ||
request_data["messages"][0]["role"] = "user"; | ||
request_data["messages"][0]["content"] = prompt; | ||
// Attach the latest image | ||
{ | ||
const std::lock_guard<std::mutex> lock(image_mutex_); | ||
// If we have no image yet | ||
if (!latest_image_) { | ||
return false; | ||
} | ||
} | ||
|
||
std::string request_data_str = request_data.dump().c_str(); | ||
|
||
struct curl_slist * headers = NULL; | ||
headers = curl_slist_append(headers, "Content-Type: application/json"); | ||
headers = curl_slist_append(headers, | ||
("Authorization: Bearer " + std::string(openai_key)).c_str()); | ||
curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str()); | ||
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str()); | ||
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length()); | ||
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); | ||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback); | ||
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); | ||
CURLcode res = curl_easy_perform(curl); | ||
|
||
if (res != CURLE_OK) { | ||
RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res)); | ||
} | ||
|
||
curl_easy_cleanup(curl); | ||
curl_slist_free_all(headers); | ||
} | ||
|
||
nlohmann::json jresponse = nlohmann::json::parse(response); | ||
|
||
std::string string_response = jresponse["choices"][0]["message"]["content"].get<std::string>(); | ||
|
||
// Parse for a one-word Yes/No reply | ||
// Yes means there is an obstacle | ||
if ((string_response.find("Yes") != std::string::npos) || | ||
(string_response.find("yes") != std::string::npos)) | ||
{ | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace nav2_behavior_tree | ||
|
||
#include "behaviortree_cpp_v3/bt_factory.h" | ||
BT_REGISTER_NODES(factory) | ||
{ | ||
factory.registerNodeType<nav2_behavior_tree::MLVicinityCondition>("MLVicinityCondition"); | ||
} |