diff --git a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp index a43f2526..58e0fa26 100644 --- a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp +++ b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp @@ -714,44 +714,84 @@ inline std::optional BaseControllerInterface::read_input inline void BaseControllerInterface::write_output(const std::string& name, const T& data) { - if (data.is_empty()) { - RCLCPP_DEBUG_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Skipping publication of output '%s' due to emptiness of state", name.c_str()); - return; - } - if (outputs_.find(name) == outputs_.end()) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, "Could not find output '%s'", name.c_str()); - return; - } - EncodedStatePublishers publishers; - try { - publishers = std::get(outputs_.at(name)); - } catch (const std::bad_variant_access&) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); - return; - } - if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')", - state_representation::get_state_type_name(output_type).c_str(), - state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str()); - return; - } - auto rt_pub = std::get<2>(publishers); - if (rt_pub && rt_pub->trylock()) { + if constexpr (modulo_core::concepts::CustomT) { + if (outputs_.find(name) == outputs_.end()) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, "Could not find output '%s'", name.c_str()); + return; + } + + CustomPublishers publishers; + try { + publishers = std::get(outputs_.at(name)); + } catch (const std::bad_variant_access&) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); + return; + } + + std::shared_ptr> pub; + realtime_tools::RealtimePublisherSharedPtr rt_pub; try { - modulo_core::translators::write_message(rt_pub->msg_, data, get_node()->get_clock()->now()); - } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + pub = std::any_cast>>(publishers.first); + rt_pub = std::any_cast>(publishers.second); + } catch (const std::bad_any_cast& ex) { RCLCPP_ERROR_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), - ex.what()); + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to wrong data type: %s", name.c_str(), ex.what()); + } + + if (rt_pub && rt_pub->trylock()) { + try { + rt_pub->msg_ = data; + } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + RCLCPP_ERROR_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), + ex.what()); + } + rt_pub->unlockAndPublish(); + } + } else { + if (data.is_empty()) { + RCLCPP_DEBUG_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to emptiness of state", name.c_str()); + return; + } + if (outputs_.find(name) == outputs_.end()) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, "Could not find output '%s'", name.c_str()); + return; + } + EncodedStatePublishers publishers; + try { + publishers = std::get(outputs_.at(name)); + } catch (const std::bad_variant_access&) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); + return; + } + if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')", + state_representation::get_state_type_name(output_type).c_str(), + state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str()); + return; + } + auto rt_pub = std::get<2>(publishers); + if (rt_pub && rt_pub->trylock()) { + try { + modulo_core::translators::write_message(rt_pub->msg_, data, get_node()->get_clock()->now()); + } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + RCLCPP_ERROR_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), + ex.what()); + } + rt_pub->unlockAndPublish(); } - rt_pub->unlockAndPublish(); } } diff --git a/source/modulo_controllers/test/test_controller_interface.cpp b/source/modulo_controllers/test/test_controller_interface.cpp index 7917aecb..540c0f11 100644 --- a/source/modulo_controllers/test/test_controller_interface.cpp +++ b/source/modulo_controllers/test/test_controller_interface.cpp @@ -198,22 +198,10 @@ TYPED_TEST_P(ControllerInterfaceTest, CustomOutputTest) { auto node_state = interface->get_node()->configure(); ASSERT_EQ(node_state.id(), lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE); - // rclcpp::Node test_node("test_node"); - // auto publisher = test_node.create_publisher("/input", rclcpp::SystemDefaultsQoS()); - - // for (auto [message_data, write_func, read_func, validation_func] : this->test_cases_) { - // message_data = write_func(message_data); - // auto message = std::get<1>(message_data); - // publisher->publish(message); - // rclcpp::spin_some(this->interface_->get_node()->get_node_base_interface()); - // auto input = this->interface_->template read_input("input"); - // ASSERT_TRUE(input); - // EXPECT_TRUE(validation_func(message_data, std::make_tuple(*input, message))); - // std::this_thread::sleep_for(100ms); - // ASSERT_FALSE(this->interface_->template read_input("input")); - // } - - EXPECT_TRUE(true); + node_state = interface->get_node()->activate(); + ASSERT_EQ(node_state.id(), lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE); + + interface->write_output("output", sensor_msgs::msg::Image()); } REGISTER_TYPED_TEST_CASE_P(ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest, CustomOutputTest);