forked from PickNikRobotics/moveit_pro_empty_ws
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Paul Gesel <[email protected]>
- Loading branch information
Showing
9 changed files
with
508 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#pragma once | ||
|
||
#include <future> | ||
#include <memory> | ||
#include <string> | ||
|
||
#include <moveit_pro_ml/onnx_sam2.hpp> | ||
#include <moveit_pro_ml/onnx_sam2_types.hpp> | ||
#include <moveit_studio_behavior_interface/async_behavior_base.hpp> | ||
#include <moveit_studio_vision_msgs/msg/mask2_d.hpp> | ||
#include <sensor_msgs/msg/image.hpp> | ||
#include <tl_expected/expected.hpp> | ||
|
||
|
||
namespace moveit_pro_ml | ||
{ | ||
class L2GModel; | ||
} | ||
|
||
namespace example_behaviors | ||
{ | ||
/** | ||
* @brief Segment an image using the SAM 2 model | ||
*/ | ||
class L2GBehavior : public moveit_studio::behaviors::AsyncBehaviorBase | ||
{ | ||
public: | ||
/** | ||
* @brief Constructor for the L2GBehavior behavior. | ||
* @param name The name of a particular instance of this Behavior. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree. | ||
* @param config This contains runtime configuration info for this Behavior, such as the mapping between the Behavior's data ports on the behavior tree's blackboard. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree. | ||
* @details An important limitation is that the members of the base Behavior class are not instantiated until after the initialize() function is called, so these classes should not be used within the constructor. | ||
*/ | ||
L2GBehavior(const std::string& name, const BT::NodeConfiguration& config, | ||
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources); | ||
|
||
/** | ||
* @brief Implementation of the required providedPorts() function for the Behavior. | ||
* @details The BehaviorTree.CPP library requires that Behaviors must implement a static function named providedPorts() which defines their input and output ports. If the Behavior does not use any ports, this function must return an empty BT::PortsList. | ||
* This function returns a list of ports with their names and port info, which is used internally by the behavior tree. | ||
* @return List of ports for the behavior. | ||
*/ | ||
static BT::PortsList providedPorts(); | ||
|
||
/** | ||
* @brief Implementation of the metadata() function for displaying metadata, such as Behavior description and | ||
* subcategory, in the MoveIt Studio Developer Tool. | ||
* @return A BT::KeyValueVector containing the Behavior metadata. | ||
*/ | ||
static BT::KeyValueVector metadata(); | ||
|
||
protected: | ||
tl::expected<bool, std::string> doWork() override; | ||
|
||
|
||
private: | ||
std::shared_ptr<moveit_pro_ml::L2GModel> l2g_; | ||
|
||
/** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */ | ||
std::shared_future<tl::expected<bool, std::string>>& getFuture() override | ||
{ | ||
return future_; | ||
} | ||
|
||
/** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */ | ||
std::shared_future<tl::expected<bool, std::string>> future_; | ||
|
||
}; | ||
} // namespace sam2_segmentation |
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
#include <future> | ||
#include <memory> | ||
#include <string> | ||
|
||
#include <ament_index_cpp/get_package_share_directory.hpp> | ||
#include <example_behaviors/l2g.hpp> | ||
#include <geometry_msgs/msg/point_stamped.hpp> | ||
#include <geometry_msgs/msg/pose_stamped.hpp> | ||
#include <moveit_studio_behavior_interface/async_behavior_base.hpp> | ||
#include <moveit_studio_behavior_interface/get_required_ports.hpp> | ||
#include <moveit_studio_vision/pointcloud/point_cloud_tools.hpp> | ||
#include <moveit_studio_vision_msgs/msg/mask2_d.hpp> | ||
#include <sensor_msgs/msg/image.hpp> | ||
#include <sensor_msgs/msg/point_cloud2.hpp> | ||
#include <tl_expected/expected.hpp> | ||
|
||
#include <pcl_conversions/pcl_conversions.h> | ||
#include <pcl/point_cloud.h> | ||
#include <pcl/filters/filter.h> | ||
|
||
|
||
#include "moveit_pro_ml/onnx_model.hpp" | ||
|
||
namespace moveit_pro_ml | ||
{ | ||
class L2GModel | ||
{ | ||
public: | ||
L2GModel(const std::filesystem::path& onnx_file) | ||
: model{std::make_shared<ONNXTensorModel>(onnx_file)} | ||
{ | ||
} | ||
|
||
[[nodiscard]] std::vector<float> predict(const std::vector<std::array<float, 3>>& points) const | ||
{ | ||
// move image into tensor | ||
int num_points = points.size(); | ||
const auto point_data = model->create_dynamic_tensor<float>(model->dynamic_inputs.at("point_cloud"), | ||
{1, num_points, 3}); | ||
std::copy_n(points.data()->data(), num_points * 3, point_data); | ||
|
||
auto pred = model->predict_base(model->inputs, model->dynamic_inputs); | ||
|
||
// copy mask out and return reference | ||
auto shape = pred.at("predicted_grasps").onnx_shape; | ||
const auto predicted_grasps_data = model->get_tensor_data<float>(pred.at("predicted_grasps")); | ||
// const auto grasp_scores_data = model->get_tensor_data<float>(pred.at("grasp_scores")); | ||
|
||
return {predicted_grasps_data, predicted_grasps_data + get_size(shape)}; | ||
} | ||
|
||
std::shared_ptr<ONNXTensorModel> model; | ||
}; | ||
} // namespace moveit_pro_ml | ||
|
||
|
||
namespace | ||
{ | ||
constexpr auto kPortPointCloud = "point_cloud"; | ||
constexpr auto kPortPointCloudDefault = "{point_cloud}"; | ||
constexpr auto kPortGrasps = "grasps"; | ||
constexpr auto kPortGraspsDefault = "{grasps}"; | ||
} // namespace | ||
|
||
namespace example_behaviors | ||
{ | ||
L2GBehavior::L2GBehavior(const std::string& name, const BT::NodeConfiguration& config, | ||
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources) | ||
: moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources) | ||
{ | ||
const std::filesystem::path package_path = ament_index_cpp::get_package_share_directory("example_behaviors"); | ||
const std::filesystem::path onnx_file = package_path / "models" / "l2g.onnx"; | ||
l2g_ = std::make_unique<moveit_pro_ml::L2GModel>(onnx_file); | ||
} | ||
|
||
BT::PortsList L2GBehavior::providedPorts() | ||
{ | ||
return { | ||
BT::InputPort<sensor_msgs::msg::PointCloud2>(kPortPointCloud, kPortPointCloudDefault, | ||
"The Image to run segmentation on."), | ||
BT::OutputPort<std::vector<moveit_studio_vision_msgs::msg::Mask2D>>(kPortGrasps, kPortGraspsDefault, | ||
"The masks contained in a vector of <code>moveit_studio_vision_msgs::msg::Mask2D</code> messages.") | ||
}; | ||
} | ||
|
||
tl::expected<bool, std::string> L2GBehavior::doWork() | ||
{ | ||
const auto ports = moveit_studio::behaviors::getRequiredInputs( | ||
getInput<sensor_msgs::msg::PointCloud2>(kPortPointCloud)); | ||
|
||
// Check that all required input data ports were set. | ||
if (!ports.has_value()) | ||
{ | ||
auto error_message = fmt::format("Failed to get required values from input data ports:\n{}", ports.error()); | ||
return tl::make_unexpected(error_message); | ||
} | ||
const auto& [point_cloud_msg] = ports.value(); | ||
|
||
const auto cloud = std::make_shared<pcl::PointCloud<pcl::PointXYZ>>(); | ||
pcl::fromROSMsg(point_cloud_msg, *cloud); | ||
auto filtered_cloud = std::make_shared<pcl::PointCloud<pcl::PointXYZ>>(); | ||
pcl::Indices index; | ||
pcl::removeNaNFromPointCloud(*cloud, *filtered_cloud, index); | ||
shared_resources_->logger->publishWarnMessage("Pointcloud size: " + filtered_cloud->size()); | ||
double fraction = std::min(10000.0 / filtered_cloud->points.size(), 1.0); | ||
auto downsampled_cloud = moveit_studio::point_cloud_tools::downsampleRandom(filtered_cloud, fraction); | ||
shared_resources_->logger->publishWarnMessage("Downsampled pointcloud size: " + downsampled_cloud->size()); | ||
|
||
try | ||
{ | ||
std::vector<std::array<float, 3>> points; | ||
points.reserve(downsampled_cloud->points.size()); | ||
for (auto& point : downsampled_cloud->points) | ||
{ | ||
points.push_back({point.x, point.y, point.z}); | ||
} | ||
|
||
const auto grasps = l2g_->predict(points); | ||
std::vector<geometry_msgs::msg::PoseStamped> grasps_pose; | ||
for (size_t i = 0; i < grasps.size(); i += 7) | ||
{ | ||
geometry_msgs::msg::PoseStamped pose; | ||
pose.header = point_cloud_msg.header; | ||
pose.header.stamp = shared_resources_->node->now(); | ||
|
||
pose.pose.position.x = (grasps[i + 0] + grasps[i + 3]) / 2.0; | ||
pose.pose.position.y = (grasps[i + 1] + grasps[i + 4]) / 2.0; | ||
pose.pose.position.z = (grasps[i + 2] + grasps[i + 5]) / 2.0; | ||
grasps_pose.push_back(pose); | ||
|
||
if (i >= 70) | ||
{ | ||
break; | ||
} | ||
} | ||
|
||
setOutput<std::vector<geometry_msgs::msg::PoseStamped>>(kPortGrasps, grasps_pose); | ||
} | ||
catch (const std::invalid_argument& e) | ||
{ | ||
return tl::make_unexpected(fmt::format("Invalid argument: {}", e.what())); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
BT::KeyValueVector L2GBehavior::metadata() | ||
{ | ||
return { | ||
{ | ||
"description", | ||
"Generate a grasp from a point cloud and output as a <code>geometry_msgs/PoseStamped</code> message." | ||
} | ||
}; | ||
} | ||
} // namespace sam2_segmentation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
<?xml version="1.0" encoding="utf-8" ?> | ||
<root BTCPP_format="4" main_tree_to_execute="L2G"> | ||
<!--//////////--> | ||
<BehaviorTree | ||
ID="L2G" | ||
_description="Take a pointcloud snapshot of the scene with a depth camera" | ||
_favorite="true" | ||
_subtreeOnly="false" | ||
> | ||
<Control ID="Sequence" name="TopLevelSequence"> | ||
<!--Clear out old snapshot data--> | ||
<Action ID="ClearSnapshot" /> | ||
<Action ID="GetImage" topic_name="/wrist_camera/color" /> | ||
<Action | ||
ID="GetPointsFromUser" | ||
point_prompts="Select the object to be segmented;again;again" | ||
point_names="Point1;Point2;Point3" | ||
view_name="/wrist_camera/color" | ||
/> | ||
<Action ID="SAM2Segmentation" /> | ||
<Action ID="GetPointCloud" topic_name="/wrist_camera/points" /> | ||
<Action | ||
ID="GetCameraInfo" | ||
topic_name="/wrist_camera/camera_info" | ||
message_out="{camera_info}" | ||
timeout_sec="5.000000" | ||
/> | ||
<Action ID="GetMasks3DFromMasks2D" /> | ||
<Decorator ID="ForEachMask3D" vector_in="{masks3d}" out="{mask3d}"> | ||
<Action ID="GetPointCloudFromMask3D" point_cloud="{point_cloud}" /> | ||
</Decorator> | ||
<Action ID="SendPointCloudToUI" point_cloud="{point_cloud_fragment}" /> | ||
<Action ID="PublishPointCloud" point_cloud="{point_cloud_fragment}" /> | ||
<Action ID="L2GBehavior" point_cloud="{point_cloud_fragment}" /> | ||
<Decorator | ||
ID="ForEachPoseStamped" | ||
vector_in="{grasps}" | ||
out="{stamped_pose}" | ||
> | ||
<Control ID="Sequence"> | ||
<Action ID="VisualizePose" /> | ||
<Action ID="BreakpointSubscriber" /> | ||
</Control> | ||
</Decorator> | ||
</Control> | ||
</BehaviorTree> | ||
<TreeNodesModel> | ||
<SubTree ID="L2G" /> | ||
</TreeNodesModel> | ||
</root> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.