Skip to content

Commit

Permalink
feat: support generic inputs in controllers
Browse files Browse the repository at this point in the history
  • Loading branch information
bpapaspyros committed Oct 3, 2024
1 parent 49db694 commit 643c6a1
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ typedef std::variant<
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::Float64>>,
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::Float64MultiArray>>,
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::Int32>>,
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::String>>>
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::String>>, std::any>
SubscriptionVariant;

typedef std::variant<
Expand All @@ -42,7 +42,7 @@ typedef std::variant<
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::Float64>>,
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::Float64MultiArray>>,
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::Int32>>,
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::String>>>
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::String>>, std::any>
BufferVariant;

typedef std::tuple<
Expand Down Expand Up @@ -81,6 +81,7 @@ typedef std::variant<
* @brief Input structure to save topic data in a realtime buffer and timestamps in one object.
*/
struct ControllerInput {
ControllerInput() = default;
ControllerInput(BufferVariant buffer_variant) : buffer(std::move(buffer_variant)) {}
BufferVariant buffer;
std::chrono::time_point<std::chrono::steady_clock> timestamp;
Expand Down Expand Up @@ -479,6 +480,8 @@ class BaseControllerInterface : public controller_interface::ControllerInterface

std::map<std::string, std::function<void(CustomPublishers&, const std::string&)>>
custom_output_configuration_callables_;
std::map<std::string, std::function<void(const std::string&, const std::string&)>>
custom_input_configuration_callables_;
};

template<typename T>
Expand Down Expand Up @@ -523,11 +526,32 @@ inline void BaseControllerInterface::set_parameter_value(const std::string& name

template<typename T>
inline void BaseControllerInterface::add_input(const std::string& name, const std::string& topic_name) {
auto buffer = realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>();
auto input = ControllerInput(buffer);
create_input(input, name, topic_name);
input_message_pairs_.insert_or_assign(
name, modulo_core::communication::make_shared_message_pair(std::make_shared<T>(), get_node()->get_clock()));
if constexpr (modulo_core::concepts::CustomT<T>) {
auto buffer = realtime_tools::RealtimeBuffer<T>();
auto input = ControllerInput(buffer);
auto parsed_name = validate_and_declare_signal(name, "input", topic_name);
if (!parsed_name.empty()) {
inputs_.insert_or_assign(parsed_name, input);
custom_input_configuration_callables_.insert_or_assign(
name, [this](const std::string& name, const std::string& topic) {
auto subscription =
get_node()->create_subscription<T>(topic, qos_, [this, name](const std::shared_ptr<T> message) {
// TODO: should we try catch bad_any_cast in here?
auto buffer =
std::any_cast<realtime_tools::RealtimeBuffer<std::shared_ptr<T>>>(inputs_.at(name).buffer);
buffer.writeFromNonRT(message);
inputs_.at(name).timestamp = std::chrono::steady_clock::now();
});
subscriptions_.push_back(subscription);
});
}
} else {
auto buffer = realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>();
auto input = ControllerInput(buffer);
create_input(input, name, topic_name);
input_message_pairs_.insert_or_assign(
name, modulo_core::communication::make_shared_message_pair(std::make_shared<T>(), get_node()->get_clock()));
}
}

template<>
Expand Down Expand Up @@ -626,36 +650,49 @@ inline void BaseControllerInterface::add_output<std::string>(const std::string&

template<typename T>
inline std::optional<T> BaseControllerInterface::read_input(const std::string& name) {
if (!check_input_valid(name)) {
return {};
}
auto message =
**std::get<realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>>(inputs_.at(name).buffer)
.readFromNonRT();
std::shared_ptr<state_representation::State> state;
try {
auto message_pair = input_message_pairs_.at(name);
message_pair->read<modulo_core::EncodedState, state_representation::State>(message);
state = message_pair->get_message_pair<modulo_core::EncodedState, state_representation::State>()->get_data();
} catch (const std::exception& ex) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what());
return {};
}
if (state->is_empty()) {
if constexpr (modulo_core::concepts::CustomT<T>) {
if (!check_input_valid(name)) {
return {};
}
try {
auto buffer = std::any_cast<realtime_tools::RealtimeBuffer<std::shared_ptr<T>>>(inputs_.at(name).buffer);
return **(buffer.readFromNonRT());
} catch (const std::bad_any_cast& ex) {
RCLCPP_ERROR(get_node()->get_logger(), "Failed to read custom input: %s", ex.what());
}
return {};
}
auto cast_ptr = std::dynamic_pointer_cast<T>(state);
if (cast_ptr != nullptr) {
return *cast_ptr;
} else {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(),
get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str());
if (!check_input_valid(name)) {
return {};
}
auto message =
**std::get<realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>>(inputs_.at(name).buffer)
.readFromNonRT();
std::shared_ptr<state_representation::State> state;
try {
auto message_pair = input_message_pairs_.at(name);
message_pair->read<modulo_core::EncodedState, state_representation::State>(message);
state = message_pair->get_message_pair<modulo_core::EncodedState, state_representation::State>()->get_data();
} catch (const std::exception& ex) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what());
return {};
}
if (state->is_empty()) {
return {};
}
auto cast_ptr = std::dynamic_pointer_cast<T>(state);
if (cast_ptr != nullptr) {
return *cast_ptr;
} else {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(),
get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str());
}
return {};
}
return {};
}

template<>
Expand Down
5 changes: 4 additions & 1 deletion source/modulo_controllers/src/BaseControllerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ void BaseControllerInterface::create_input(
const ControllerInput& input, const std::string& name, const std::string& topic_name) {
auto parsed_name = validate_and_declare_signal(name, "input", topic_name);
if (!parsed_name.empty()) {
inputs_.insert_or_assign(name, input);
inputs_.insert_or_assign(parsed_name, input);
}
}

Expand All @@ -357,6 +357,9 @@ void BaseControllerInterface::add_inputs() {
},
[&](const realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::String>>&) {
subscriptions_.push_back(create_subscription<std_msgs::msg::String>(name, topic));
},
[&](const std::any&) {
custom_input_configuration_callables_.at(name)(name, topic);
}},
input.buffer);
} catch (const std::exception& ex) {
Expand Down
19 changes: 18 additions & 1 deletion source/modulo_controllers/test/test_controller_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,24 @@ TYPED_TEST_P(ControllerInterfaceTest, CustomOutputTest) {
interface->write_output<sensor_msgs::msg::Image>("output", sensor_msgs::msg::Image());
}

REGISTER_TYPED_TEST_CASE_P(ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest, CustomOutputTest);
TYPED_TEST_P(ControllerInterfaceTest, CustomInputTest) {
auto interface = std::make_unique<FriendControllerInterface>();
interface->init("controller_interface", "", 0, "", interface->define_custom_node_options());
interface->get_node()->set_parameter({"hardware_name", "test"});
interface->get_node()->set_parameter({"input_validity_period", 0.1});

interface->add_input<sensor_msgs::msg::Image>("input", "/input");
auto node_state = interface->get_node()->configure();
ASSERT_EQ(node_state.id(), lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE);

node_state = interface->get_node()->activate();
ASSERT_EQ(node_state.id(), lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE);

auto msg = interface->read_input<sensor_msgs::msg::Image>("input");
}

REGISTER_TYPED_TEST_CASE_P(
ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest, CustomOutputTest, CustomInputTest);

typedef ::testing::Types<BoolT, DoubleT, DoubleVecT, IntT, StringT, CartesianStateT, JointStateT> SignalTypes;
INSTANTIATE_TYPED_TEST_CASE_P(TestPrefix, ControllerInterfaceTest, SignalTypes);

0 comments on commit 643c6a1

Please sign in to comment.