diff --git a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp index 58e0fa26..8ea91d98 100644 --- a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp +++ b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp @@ -33,7 +33,7 @@ typedef std::variant< std::shared_ptr>, std::shared_ptr>, std::shared_ptr>, - std::shared_ptr>> + std::shared_ptr>, std::any> SubscriptionVariant; typedef std::variant< @@ -42,7 +42,7 @@ typedef std::variant< realtime_tools::RealtimeBuffer>, realtime_tools::RealtimeBuffer>, realtime_tools::RealtimeBuffer>, - realtime_tools::RealtimeBuffer>> + realtime_tools::RealtimeBuffer>, std::any> BufferVariant; typedef std::tuple< @@ -81,6 +81,7 @@ typedef std::variant< * @brief Input structure to save topic data in a realtime buffer and timestamps in one object. */ struct ControllerInput { + ControllerInput() = default; ControllerInput(BufferVariant buffer_variant) : buffer(std::move(buffer_variant)) {} BufferVariant buffer; std::chrono::time_point timestamp; @@ -479,6 +480,8 @@ class BaseControllerInterface : public controller_interface::ControllerInterface std::map> custom_output_configuration_callables_; + std::map> + custom_input_configuration_callables_; }; template @@ -523,11 +526,32 @@ inline void BaseControllerInterface::set_parameter_value(const std::string& name template inline void BaseControllerInterface::add_input(const std::string& name, const std::string& topic_name) { - auto buffer = realtime_tools::RealtimeBuffer>(); - auto input = ControllerInput(buffer); - create_input(input, name, topic_name); - input_message_pairs_.insert_or_assign( - name, modulo_core::communication::make_shared_message_pair(std::make_shared(), get_node()->get_clock())); + if constexpr (modulo_core::concepts::CustomT) { + auto buffer = realtime_tools::RealtimeBuffer(); + auto input = ControllerInput(buffer); + auto parsed_name = validate_and_declare_signal(name, "input", topic_name); + if (!parsed_name.empty()) { + inputs_.insert_or_assign(parsed_name, input); + custom_input_configuration_callables_.insert_or_assign( + name, [this](const std::string& name, const std::string& topic) { + auto subscription = + get_node()->create_subscription(topic, qos_, [this, name](const std::shared_ptr message) { + // TODO: should we try catch bad_any_cast in here? + auto buffer = + std::any_cast>>(inputs_.at(name).buffer); + buffer.writeFromNonRT(message); + inputs_.at(name).timestamp = std::chrono::steady_clock::now(); + }); + subscriptions_.push_back(subscription); + }); + } + } else { + auto buffer = realtime_tools::RealtimeBuffer>(); + auto input = ControllerInput(buffer); + create_input(input, name, topic_name); + input_message_pairs_.insert_or_assign( + name, modulo_core::communication::make_shared_message_pair(std::make_shared(), get_node()->get_clock())); + } } template<> @@ -626,36 +650,49 @@ inline void BaseControllerInterface::add_output(const std::string& template inline std::optional BaseControllerInterface::read_input(const std::string& name) { - if (!check_input_valid(name)) { - return {}; - } - auto message = - **std::get>>(inputs_.at(name).buffer) - .readFromNonRT(); - std::shared_ptr state; - try { - auto message_pair = input_message_pairs_.at(name); - message_pair->read(message); - state = message_pair->get_message_pair()->get_data(); - } catch (const std::exception& ex) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what()); - return {}; - } - if (state->is_empty()) { + if constexpr (modulo_core::concepts::CustomT) { + if (!check_input_valid(name)) { + return {}; + } + try { + auto buffer = std::any_cast>>(inputs_.at(name).buffer); + return **(buffer.readFromNonRT()); + } catch (const std::bad_any_cast& ex) { + RCLCPP_ERROR(get_node()->get_logger(), "Failed to read custom input: %s", ex.what()); + } return {}; - } - auto cast_ptr = std::dynamic_pointer_cast(state); - if (cast_ptr != nullptr) { - return *cast_ptr; } else { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(), - get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str()); + if (!check_input_valid(name)) { + return {}; + } + auto message = + **std::get>>(inputs_.at(name).buffer) + .readFromNonRT(); + std::shared_ptr state; + try { + auto message_pair = input_message_pairs_.at(name); + message_pair->read(message); + state = message_pair->get_message_pair()->get_data(); + } catch (const std::exception& ex) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what()); + return {}; + } + if (state->is_empty()) { + return {}; + } + auto cast_ptr = std::dynamic_pointer_cast(state); + if (cast_ptr != nullptr) { + return *cast_ptr; + } else { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(), + get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str()); + } + return {}; } - return {}; } template<> diff --git a/source/modulo_controllers/src/BaseControllerInterface.cpp b/source/modulo_controllers/src/BaseControllerInterface.cpp index 65b22327..42686fd3 100644 --- a/source/modulo_controllers/src/BaseControllerInterface.cpp +++ b/source/modulo_controllers/src/BaseControllerInterface.cpp @@ -330,7 +330,7 @@ void BaseControllerInterface::create_input( const ControllerInput& input, const std::string& name, const std::string& topic_name) { auto parsed_name = validate_and_declare_signal(name, "input", topic_name); if (!parsed_name.empty()) { - inputs_.insert_or_assign(name, input); + inputs_.insert_or_assign(parsed_name, input); } } @@ -357,6 +357,9 @@ void BaseControllerInterface::add_inputs() { }, [&](const realtime_tools::RealtimeBuffer>&) { subscriptions_.push_back(create_subscription(name, topic)); + }, + [&](const std::any&) { + custom_input_configuration_callables_.at(name)(name, topic); }}, input.buffer); } catch (const std::exception& ex) { diff --git a/source/modulo_controllers/test/test_controller_interface.cpp b/source/modulo_controllers/test/test_controller_interface.cpp index 540c0f11..f8219650 100644 --- a/source/modulo_controllers/test/test_controller_interface.cpp +++ b/source/modulo_controllers/test/test_controller_interface.cpp @@ -204,7 +204,24 @@ TYPED_TEST_P(ControllerInterfaceTest, CustomOutputTest) { interface->write_output("output", sensor_msgs::msg::Image()); } -REGISTER_TYPED_TEST_CASE_P(ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest, CustomOutputTest); +TYPED_TEST_P(ControllerInterfaceTest, CustomInputTest) { + auto interface = std::make_unique(); + interface->init("controller_interface", "", 0, "", interface->define_custom_node_options()); + interface->get_node()->set_parameter({"hardware_name", "test"}); + interface->get_node()->set_parameter({"input_validity_period", 0.1}); + + interface->add_input("input", "/input"); + auto node_state = interface->get_node()->configure(); + ASSERT_EQ(node_state.id(), lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE); + + node_state = interface->get_node()->activate(); + ASSERT_EQ(node_state.id(), lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE); + + auto msg = interface->read_input("input"); +} + +REGISTER_TYPED_TEST_CASE_P( + ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest, CustomOutputTest, CustomInputTest); typedef ::testing::Types SignalTypes; INSTANTIATE_TYPED_TEST_CASE_P(TestPrefix, ControllerInterfaceTest, SignalTypes);