-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a BT condition to check robot's vicinity with an AI model (currently ChatGPT) #4486
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
nav2_behavior_tree/include/nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright (c) 2024 Andy Zelenak | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ | ||
#define NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ | ||
|
||
#include <string> | ||
|
||
#include "behaviortree_cpp/condition_node.h" | ||
#include "rclcpp/rclcpp.hpp" | ||
#include "sensor_msgs/msg/image.hpp" | ||
|
||
namespace nav2_behavior_tree | ||
{ | ||
|
||
/** | ||
* @brief A BT::ConditionNode, returns SUCCESS when a large language model says the vicinity is clear | ||
* The input data has type sensor_msgs::msg::Image. | ||
*/ | ||
class MLVicinityCondition : public BT::ConditionNode | ||
{ | ||
public: | ||
/** | ||
* @brief A constructor for nav2_behavior_tree::MLVicinityCondition | ||
* @param condition_name Name for the XML tag for this node | ||
* @param conf BT node configuration | ||
*/ | ||
MLVicinityCondition( | ||
const std::string & condition_name, | ||
const BT::NodeConfiguration & conf); | ||
|
||
MLVicinityCondition() = delete; | ||
|
||
/** | ||
* @brief The main override required by a BT action | ||
* @return BT::NodeStatus Status of tick execution | ||
*/ | ||
BT::NodeStatus tick() override; | ||
|
||
/** | ||
* @brief Creates list of BT ports | ||
* @return BT::PortsList Containing node-specific ports | ||
*/ | ||
static BT::PortsList providedPorts() | ||
{ | ||
return { | ||
BT::InputPort<std::string>("image_topic", "Image topic which is subscribed to"), | ||
BT::InputPort<std::chrono::milliseconds>("server_timeout") | ||
}; | ||
} | ||
|
||
private: | ||
/** | ||
* @brief Capture the latest image to send to the ML model | ||
* | ||
* @param msg | ||
*/ | ||
void imageCallback(const sensor_msgs::msg::Image& msg) | ||
{ | ||
const std::lock_guard<std::mutex> lock(image_mutex_); | ||
latest_image_ = cv_bridge::toCvShare(msg, sensor_msgs::image_encodings::TYPE_32FC1); | ||
} | ||
|
||
/** | ||
* @brief Send the prompt (or otherwise run) the AI model. This includes the latest image as input. | ||
* | ||
* @return true if successful | ||
*/ | ||
[[nodiscard]] bool promptAIModel(); | ||
|
||
rclcpp::Node::SharedPtr node_; | ||
// The timeout value while waiting for a response from the | ||
// is path valid service | ||
std::chrono::milliseconds server_timeout_; | ||
rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_; | ||
mutable std::mutex image_mutex_; | ||
std::optional<cv_bridge::CvImageConstPtr> latest_image_; | ||
}; | ||
|
||
} // namespace nav2_behavior_tree | ||
|
||
#endif // NAV2_BEHAVIOR_TREE__PLUGINS__CONDITION__ML_VICINITY_CONDITION_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
133 changes: 133 additions & 0 deletions
133
nav2_behavior_tree/plugins/condition/ml_vicinity_condition.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright (c) 2024 Andy Zelenak | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "nav2_behavior_tree/plugins/condition/ml_vicinity_condition.hpp" | ||
#include <chrono> | ||
#include <cstdlib> | ||
#include <string> | ||
#include <cv_bridge/cv_bridge.hpp> | ||
#include <curl/curl.h> | ||
#include <nlohmann/json.hpp> | ||
|
||
namespace nav2_behavior_tree | ||
{ | ||
|
||
MLVicinityCondition::MLVicinityCondition( | ||
const std::string & condition_name, | ||
const BT::NodeConfiguration & conf) | ||
: BT::ConditionNode(condition_name, conf) | ||
{ | ||
node_ = config().blackboard->get<rclcpp::Node::SharedPtr>("node"); | ||
|
||
server_timeout_ = config().blackboard->template get<std::chrono::milliseconds>("server_timeout"); | ||
getInput<std::chrono::milliseconds>("server_timeout", server_timeout_); | ||
std::string image_topic; | ||
getInput<std::string>("image_topic", image_topic); | ||
image_sub_ = node_->create_subscription<sensor_msgs::msg::Image>( | ||
image_topic, 10, std::bind(&MLVicinityCondition::imageCallback, | ||
this, std::placeholders::_1)); | ||
} | ||
|
||
BT::NodeStatus MLVicinityCondition::tick() | ||
{ | ||
if (promptAIModel()) { | ||
return BT::NodeStatus::SUCCESS; | ||
} | ||
return BT::NodeStatus::FAILURE; | ||
} | ||
|
||
bool MLVicinityCondition::promptAIModel() | ||
{ | ||
// Get OpenAI key from environment variable | ||
const char * openai_key = std::getenv("OPENAI_API_KEY"); | ||
RCLCPP_INFO_STREAM(node_->get_logger(), openai_key); | ||
|
||
const std::string base_url = "https://api.openai.com/v1/chat/completions"; | ||
std::string response; | ||
CURL * curl = curl_easy_init(); | ||
|
||
const std::string prompt = "Are there any obstacles in the image which would prevent forward " | ||
"motion or possibly damage a robot as it moves forward? Note that a small doorsill is not a " | ||
"significant obstacle. Please answer in one word, yes or no."; | ||
|
||
// Lambda to handle service response | ||
auto write_callback = [&](void * contents, size_t size, size_t nmemb, std::string * response) | ||
{ | ||
size_t total_size = size * nmemb; | ||
response->append(reinterpret_cast<char *>(contents), total_size); | ||
return total_size; | ||
}; | ||
|
||
if (curl) { | ||
// See https://drqwrites.medium.com/accessing-the-openai-api-using-c-a3e527b6584b | ||
nlohmann::json request_data; | ||
request_data["model"] = "gpt-4o"; | ||
request_data["messages"][0]["role"] = "user"; | ||
// Attach the latest image | ||
std::string encoded_image = base64_encode(enc_msg, buf.size()); | ||
{ | ||
const std::lock_guard<std::mutex> lock(image_mutex_); | ||
// If we have no image yet | ||
if (!latest_image_) { | ||
return false; | ||
} | ||
std::vector<uchar> buf; | ||
cv::imencode(".jpg", img, buf); | ||
auto *enc_msg = reinterpret_cast<unsigned char*>(buf.data()); | ||
encoded_image = base64_encode(enc_msg, buf.size()); | ||
} | ||
request_data["messages"][0]["content"] = {prompt, encoded_image}; | ||
|
||
std::string request_data_str = request_data.dump().c_str(); | ||
|
||
struct curl_slist * headers = NULL; | ||
headers = curl_slist_append(headers, "Content-Type: application/json"); | ||
headers = curl_slist_append(headers, | ||
("Authorization: Bearer " + std::string(openai_key)).c_str()); | ||
curl_easy_setopt(curl, CURLOPT_URL, base_url.c_str()); | ||
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_data_str.c_str()); | ||
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request_data_str.length()); | ||
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); | ||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback); | ||
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); | ||
CURLcode res = curl_easy_perform(curl); | ||
|
||
if (res != CURLE_OK) { | ||
RCLCPP_ERROR_STREAM(node_->get_logger(), "Curl request failed: " << curl_easy_strerror(res)); | ||
} | ||
|
||
curl_easy_cleanup(curl); | ||
curl_slist_free_all(headers); | ||
} | ||
|
||
nlohmann::json jresponse = nlohmann::json::parse(response); | ||
|
||
std::string string_response = jresponse["choices"][0]["message"]["content"].get<std::string>(); | ||
|
||
// Parse for a one-word Yes/No reply | ||
// Yes means there is an obstacle | ||
if ((string_response.find("Yes") != std::string::npos) || | ||
(string_response.find("yes") != std::string::npos)) | ||
{ | ||
return false; | ||
} | ||
return true; | ||
} | ||
} // namespace nav2_behavior_tree | ||
|
||
#include "behaviortree_cpp/bt_factory.h" | ||
BT_REGISTER_NODES(factory) | ||
{ | ||
factory.registerNodeType<nav2_behavior_tree::MLVicinityCondition>("MLVicinityCondition"); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i figured i'd cram everything related to ChatGPT into this one function so it can be swapped out easily, although it's not the most efficient way