Skip to content

Commit

Permalink
add model and behavior
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Gesel <[email protected]>
  • Loading branch information
pac48 committed Dec 12, 2024
1 parent a657adf commit c17d4a6
Show file tree
Hide file tree
Showing 9 changed files with 508 additions and 182 deletions.
4 changes: 3 additions & 1 deletion src/example_behaviors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ add_library(
src/example_setup_mtc_wave_hand.cpp
src/example_ndt_registration.cpp
src/example_ransac_registration.cpp
src/l2g.cpp
src/sam2_segmentation.cpp
src/register_behaviors.cpp)
target_include_directories(
example_behaviors
Expand All @@ -40,7 +42,7 @@ target_include_directories(
PRIVATE ${PCL_INCLUDE_DIRS})
ament_target_dependencies(example_behaviors
${THIS_PACKAGE_INCLUDE_DEPENDS})
target_link_libraries(example_behaviors onnx_sam2)
target_link_libraries(example_behaviors onnx_sam2 moveit_pro_ml)

# Install Libraries
install(
Expand Down
Empty file.
69 changes: 69 additions & 0 deletions src/example_behaviors/include/example_behaviors/l2g.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#pragma once

#include <future>
#include <memory>
#include <string>

#include <moveit_pro_ml/onnx_sam2.hpp>
#include <moveit_pro_ml/onnx_sam2_types.hpp>
#include <moveit_studio_behavior_interface/async_behavior_base.hpp>
#include <moveit_studio_vision_msgs/msg/mask2_d.hpp>
#include <sensor_msgs/msg/image.hpp>
#include <tl_expected/expected.hpp>


namespace moveit_pro_ml
{
class L2GModel;
}

namespace example_behaviors
{
/**
* @brief Segment an image using the SAM 2 model
*/
class L2GBehavior : public moveit_studio::behaviors::AsyncBehaviorBase
{
public:
/**
* @brief Constructor for the L2GBehavior behavior.
* @param name The name of a particular instance of this Behavior. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree.
* @param config This contains runtime configuration info for this Behavior, such as the mapping between the Behavior's data ports on the behavior tree's blackboard. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree.
* @details An important limitation is that the members of the base Behavior class are not instantiated until after the initialize() function is called, so these classes should not be used within the constructor.
*/
L2GBehavior(const std::string& name, const BT::NodeConfiguration& config,
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources);

/**
* @brief Implementation of the required providedPorts() function for the Behavior.
* @details The BehaviorTree.CPP library requires that Behaviors must implement a static function named providedPorts() which defines their input and output ports. If the Behavior does not use any ports, this function must return an empty BT::PortsList.
* This function returns a list of ports with their names and port info, which is used internally by the behavior tree.
* @return List of ports for the behavior.
*/
static BT::PortsList providedPorts();

/**
* @brief Implementation of the metadata() function for displaying metadata, such as Behavior description and
* subcategory, in the MoveIt Studio Developer Tool.
* @return A BT::KeyValueVector containing the Behavior metadata.
*/
static BT::KeyValueVector metadata();

protected:
tl::expected<bool, std::string> doWork() override;


private:
std::shared_ptr<moveit_pro_ml::L2GModel> l2g_;

/** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */
std::shared_future<tl::expected<bool, std::string>>& getFuture() override
{
return future_;
}

/** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */
std::shared_future<tl::expected<bool, std::string>> future_;

};
} // namespace sam2_segmentation
3 changes: 3 additions & 0 deletions src/example_behaviors/models/l2g.onnx
Git LFS file not shown
156 changes: 156 additions & 0 deletions src/example_behaviors/src/l2g.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#include <future>
#include <memory>
#include <string>

#include <ament_index_cpp/get_package_share_directory.hpp>
#include <example_behaviors/l2g.hpp>
#include <geometry_msgs/msg/point_stamped.hpp>
#include <geometry_msgs/msg/pose_stamped.hpp>
#include <moveit_studio_behavior_interface/async_behavior_base.hpp>
#include <moveit_studio_behavior_interface/get_required_ports.hpp>
#include <moveit_studio_vision/pointcloud/point_cloud_tools.hpp>
#include <moveit_studio_vision_msgs/msg/mask2_d.hpp>
#include <sensor_msgs/msg/image.hpp>
#include <sensor_msgs/msg/point_cloud2.hpp>
#include <tl_expected/expected.hpp>

#include <pcl_conversions/pcl_conversions.h>
#include <pcl/point_cloud.h>
#include <pcl/filters/filter.h>


#include "moveit_pro_ml/onnx_model.hpp"

namespace moveit_pro_ml
{
class L2GModel
{
public:
L2GModel(const std::filesystem::path& onnx_file)
: model{std::make_shared<ONNXTensorModel>(onnx_file)}
{
}

[[nodiscard]] std::vector<float> predict(const std::vector<std::array<float, 3>>& points) const
{
// move image into tensor
int num_points = points.size();
const auto point_data = model->create_dynamic_tensor<float>(model->dynamic_inputs.at("point_cloud"),
{1, num_points, 3});
std::copy_n(points.data()->data(), num_points * 3, point_data);

auto pred = model->predict_base(model->inputs, model->dynamic_inputs);

// copy mask out and return reference
auto shape = pred.at("predicted_grasps").onnx_shape;
const auto predicted_grasps_data = model->get_tensor_data<float>(pred.at("predicted_grasps"));
// const auto grasp_scores_data = model->get_tensor_data<float>(pred.at("grasp_scores"));

return {predicted_grasps_data, predicted_grasps_data + get_size(shape)};
}

std::shared_ptr<ONNXTensorModel> model;
};
} // namespace moveit_pro_ml


namespace
{
constexpr auto kPortPointCloud = "point_cloud";
constexpr auto kPortPointCloudDefault = "{point_cloud}";
constexpr auto kPortGrasps = "grasps";
constexpr auto kPortGraspsDefault = "{grasps}";
} // namespace

namespace example_behaviors
{
L2GBehavior::L2GBehavior(const std::string& name, const BT::NodeConfiguration& config,
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources)
: moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources)
{
const std::filesystem::path package_path = ament_index_cpp::get_package_share_directory("example_behaviors");
const std::filesystem::path onnx_file = package_path / "models" / "l2g.onnx";
l2g_ = std::make_unique<moveit_pro_ml::L2GModel>(onnx_file);
}

BT::PortsList L2GBehavior::providedPorts()
{
return {
BT::InputPort<sensor_msgs::msg::PointCloud2>(kPortPointCloud, kPortPointCloudDefault,
"The Image to run segmentation on."),
BT::OutputPort<std::vector<moveit_studio_vision_msgs::msg::Mask2D>>(kPortGrasps, kPortGraspsDefault,
"The masks contained in a vector of <code>moveit_studio_vision_msgs::msg::Mask2D</code> messages.")
};
}

tl::expected<bool, std::string> L2GBehavior::doWork()
{
const auto ports = moveit_studio::behaviors::getRequiredInputs(
getInput<sensor_msgs::msg::PointCloud2>(kPortPointCloud));

// Check that all required input data ports were set.
if (!ports.has_value())
{
auto error_message = fmt::format("Failed to get required values from input data ports:\n{}", ports.error());
return tl::make_unexpected(error_message);
}
const auto& [point_cloud_msg] = ports.value();

const auto cloud = std::make_shared<pcl::PointCloud<pcl::PointXYZ>>();
pcl::fromROSMsg(point_cloud_msg, *cloud);
auto filtered_cloud = std::make_shared<pcl::PointCloud<pcl::PointXYZ>>();
pcl::Indices index;
pcl::removeNaNFromPointCloud(*cloud, *filtered_cloud, index);
shared_resources_->logger->publishWarnMessage("Pointcloud size: " + filtered_cloud->size());
double fraction = std::min(10000.0 / filtered_cloud->points.size(), 1.0);
auto downsampled_cloud = moveit_studio::point_cloud_tools::downsampleRandom(filtered_cloud, fraction);
shared_resources_->logger->publishWarnMessage("Downsampled pointcloud size: " + downsampled_cloud->size());

try
{
std::vector<std::array<float, 3>> points;
points.reserve(downsampled_cloud->points.size());
for (auto& point : downsampled_cloud->points)
{
points.push_back({point.x, point.y, point.z});
}

const auto grasps = l2g_->predict(points);
std::vector<geometry_msgs::msg::PoseStamped> grasps_pose;
for (size_t i = 0; i < grasps.size(); i += 7)
{
geometry_msgs::msg::PoseStamped pose;
pose.header = point_cloud_msg.header;
pose.header.stamp = shared_resources_->node->now();

pose.pose.position.x = (grasps[i + 0] + grasps[i + 3]) / 2.0;
pose.pose.position.y = (grasps[i + 1] + grasps[i + 4]) / 2.0;
pose.pose.position.z = (grasps[i + 2] + grasps[i + 5]) / 2.0;
grasps_pose.push_back(pose);

if (i >= 70)
{
break;
}
}

setOutput<std::vector<geometry_msgs::msg::PoseStamped>>(kPortGrasps, grasps_pose);
}
catch (const std::invalid_argument& e)
{
return tl::make_unexpected(fmt::format("Invalid argument: {}", e.what()));
}

return true;
}

BT::KeyValueVector L2GBehavior::metadata()
{
return {
{
"description",
"Generate a grasp from a point cloud and output as a <code>geometry_msgs/PoseStamped</code> message."
}
};
}
} // namespace sam2_segmentation
4 changes: 4 additions & 0 deletions src/example_behaviors/src/register_behaviors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <example_behaviors/example_setup_mtc_place_from_pose.hpp>
#include <example_behaviors/example_ndt_registration.hpp>
#include <example_behaviors/example_ransac_registration.hpp>
#include <example_behaviors/l2g.hpp>
#include <example_behaviors/sam2_segmentation.hpp>

#include <pluginlib/class_list_macros.hpp>

Expand All @@ -40,6 +42,8 @@ class ExampleBehaviorsLoader : public moveit_studio::behaviors::SharedResourcesN
shared_resources);
moveit_studio::behaviors::registerBehavior<ExampleNDTRegistration>(factory, "ExampleNDTRegistration", shared_resources);
moveit_studio::behaviors::registerBehavior<ExampleRANSACRegistration>(factory, "ExampleRANSACRegistration", shared_resources);
moveit_studio::behaviors::registerBehavior<L2GBehavior>(factory, "L2GBehavior", shared_resources);
moveit_studio::behaviors::registerBehavior<SAM2Segmentation>(factory, "SAM2Segmentation", shared_resources);
}
};
} // namespace example_behaviors
Expand Down
50 changes: 50 additions & 0 deletions src/lab_sim/objectives/l2g.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<?xml version="1.0" encoding="utf-8" ?>
<root BTCPP_format="4" main_tree_to_execute="L2G">
<!--//////////-->
<BehaviorTree
ID="L2G"
_description="Take a pointcloud snapshot of the scene with a depth camera"
_favorite="true"
_subtreeOnly="false"
>
<Control ID="Sequence" name="TopLevelSequence">
<!--Clear out old snapshot data-->
<Action ID="ClearSnapshot" />
<Action ID="GetImage" topic_name="/wrist_camera/color" />
<Action
ID="GetPointsFromUser"
point_prompts="Select the object to be segmented;again;again"
point_names="Point1;Point2;Point3"
view_name="/wrist_camera/color"
/>
<Action ID="SAM2Segmentation" />
<Action ID="GetPointCloud" topic_name="/wrist_camera/points" />
<Action
ID="GetCameraInfo"
topic_name="/wrist_camera/camera_info"
message_out="{camera_info}"
timeout_sec="5.000000"
/>
<Action ID="GetMasks3DFromMasks2D" />
<Decorator ID="ForEachMask3D" vector_in="{masks3d}" out="{mask3d}">
<Action ID="GetPointCloudFromMask3D" point_cloud="{point_cloud}" />
</Decorator>
<Action ID="SendPointCloudToUI" point_cloud="{point_cloud_fragment}" />
<Action ID="PublishPointCloud" point_cloud="{point_cloud_fragment}" />
<Action ID="L2GBehavior" point_cloud="{point_cloud_fragment}" />
<Decorator
ID="ForEachPoseStamped"
vector_in="{grasps}"
out="{stamped_pose}"
>
<Control ID="Sequence">
<Action ID="VisualizePose" />
<Action ID="BreakpointSubscriber" />
</Control>
</Decorator>
</Control>
</BehaviorTree>
<TreeNodesModel>
<SubTree ID="L2G" />
</TreeNodesModel>
</root>
3 changes: 1 addition & 2 deletions src/lab_sim/objectives/run_sam2_onnx.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
<Control ID="Sequence">
<Action ID="ClearSnapshot" />
<Action ID="GetImage" topic_name="/wrist_camera/color" />
<Action
ID="GetPointsFromUser"
<Action ID="GetPointsFromUser"
point_prompts="Select the object to be segmented;"
point_names="Point1;"
view_name="/wrist_camera/color"
Expand Down
Loading

0 comments on commit c17d4a6

Please sign in to comment.