Skip to content

Commit

Permalink
feat!: support generic inputs and outputs in controllers (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
bpapaspyros authored and domire8 committed Oct 4, 2024
1 parent 3b3294c commit aa22a8c
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <any>
#include <mutex>

#include <controller_interface/controller_interface.hpp>
Expand All @@ -22,6 +23,8 @@

#include <modulo_utils/parsing.hpp>

#include <modulo_core/concepts.hpp>

namespace modulo_controllers {

typedef std::variant<
Expand All @@ -30,7 +33,7 @@ typedef std::variant<
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::Float64>>,
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::Float64MultiArray>>,
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::Int32>>,
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::String>>>
std::shared_ptr<rclcpp::Subscription<std_msgs::msg::String>>, std::any>
SubscriptionVariant;

typedef std::variant<
Expand All @@ -39,7 +42,7 @@ typedef std::variant<
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::Float64>>,
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::Float64MultiArray>>,
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::Int32>>,
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::String>>>
realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::String>>, std::any>
BufferVariant;

typedef std::tuple<
Expand All @@ -66,16 +69,19 @@ typedef std::pair<
std::shared_ptr<rclcpp::Publisher<std_msgs::msg::String>>,
realtime_tools::RealtimePublisherSharedPtr<std_msgs::msg::String>>
StringPublishers;
typedef std::pair<std::any, std::any> CustomPublishers;

typedef std::variant<
EncodedStatePublishers, BoolPublishers, DoublePublishers, DoubleVecPublishers, IntPublishers, StringPublishers>
EncodedStatePublishers, BoolPublishers, DoublePublishers, DoubleVecPublishers, IntPublishers, StringPublishers,
CustomPublishers>
PublisherVariant;

/**
* @struct ControllerInput
* @brief Input structure to save topic data in a realtime buffer and timestamps in one object.
*/
struct ControllerInput {
ControllerInput() = default;
ControllerInput(BufferVariant buffer_variant) : buffer(std::move(buffer_variant)) {}
BufferVariant buffer;
std::chrono::time_point<std::chrono::steady_clock> timestamp;
Expand Down Expand Up @@ -471,6 +477,11 @@ class BaseControllerInterface : public controller_interface::ControllerInterface
std::shared_ptr<rclcpp::TimerBase> predicate_timer_;

std::timed_mutex command_mutex_;

std::map<std::string, std::function<void(CustomPublishers&, const std::string&)>>
custom_output_configuration_callables_;
std::map<std::string, std::function<void(const std::string&, const std::string&)>>
custom_input_configuration_callables_;
};

template<typename T>
Expand Down Expand Up @@ -515,11 +526,36 @@ inline void BaseControllerInterface::set_parameter_value(const std::string& name

template<typename T>
inline void BaseControllerInterface::add_input(const std::string& name, const std::string& topic_name) {
auto buffer = realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>();
auto input = ControllerInput(buffer);
create_input(input, name, topic_name);
input_message_pairs_.insert_or_assign(
name, modulo_core::communication::make_shared_message_pair(std::make_shared<T>(), get_node()->get_clock()));
if constexpr (modulo_core::concepts::CustomT<T>) {
auto buffer = std::make_shared<realtime_tools::RealtimeBuffer<std::shared_ptr<T>>>();
auto input = ControllerInput(buffer);
auto parsed_name = validate_and_declare_signal(name, "input", topic_name);
if (!parsed_name.empty()) {
inputs_.insert_or_assign(parsed_name, input);
custom_input_configuration_callables_.insert_or_assign(
name, [this](const std::string& name, const std::string& topic) {
auto subscription =
get_node()->create_subscription<T>(topic, qos_, [this, name](const std::shared_ptr<T> message) {
auto buffer_variant = std::get<std::any>(inputs_.at(name).buffer);
auto buffer = std::any_cast<std::shared_ptr<realtime_tools::RealtimeBuffer<std::shared_ptr<T>>>>(
buffer_variant);
buffer->writeFromNonRT(message);
inputs_.at(name).timestamp = std::chrono::steady_clock::now();
});
subscriptions_.push_back(subscription);
});
}
} else {
auto buffer = realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>();
auto input = ControllerInput(buffer);
auto parsed_name = validate_and_declare_signal(name, "input", topic_name);
if (!parsed_name.empty()) {
inputs_.insert_or_assign(parsed_name, input);
input_message_pairs_.insert_or_assign(
parsed_name,
modulo_core::communication::make_shared_message_pair(std::make_shared<T>(), get_node()->get_clock()));
}
}
}

template<>
Expand Down Expand Up @@ -569,8 +605,22 @@ BaseControllerInterface::create_subscription(const std::string& name, const std:

template<typename T>
inline void BaseControllerInterface::add_output(const std::string& name, const std::string& topic_name) {
std::shared_ptr<state_representation::State> state_ptr = std::make_shared<T>();
create_output(EncodedStatePublishers(state_ptr, {}, {}), name, topic_name);
if constexpr (modulo_core::concepts::CustomT<T>) {
typedef std::pair<std::shared_ptr<rclcpp::Publisher<T>>, realtime_tools::RealtimePublisherSharedPtr<T>> PublisherT;
auto parsed_name = validate_and_declare_signal(name, "output", topic_name);
if (!parsed_name.empty()) {
outputs_.insert_or_assign(parsed_name, PublisherT());
custom_output_configuration_callables_.insert_or_assign(
name, [this](CustomPublishers& pub, const std::string& topic) {
auto publisher = get_node()->create_publisher<T>(topic, qos_);
pub.first = publisher;
pub.second = std::make_shared<realtime_tools::RealtimePublisher<T>>(publisher);
});
}
} else {
std::shared_ptr<state_representation::State> state_ptr = std::make_shared<T>();
create_output(EncodedStatePublishers(state_ptr, {}, {}), name, topic_name);
}
}

template<>
Expand Down Expand Up @@ -604,33 +654,45 @@ inline std::optional<T> BaseControllerInterface::read_input(const std::string& n
if (!check_input_valid(name)) {
return {};
}
auto message =
**std::get<realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>>(inputs_.at(name).buffer)
.readFromNonRT();
std::shared_ptr<state_representation::State> state;
try {
auto message_pair = input_message_pairs_.at(name);
message_pair->read<modulo_core::EncodedState, state_representation::State>(message);
state = message_pair->get_message_pair<modulo_core::EncodedState, state_representation::State>()->get_data();
} catch (const std::exception& ex) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what());
return {};
}
if (state->is_empty()) {

if constexpr (modulo_core::concepts::CustomT<T>) {
try {
auto buffer_variant = std::get<std::any>(inputs_.at(name).buffer);
auto buffer = std::any_cast<std::shared_ptr<realtime_tools::RealtimeBuffer<std::shared_ptr<T>>>>(buffer_variant);
return **(buffer->readFromNonRT());
} catch (const std::bad_any_cast& ex) {
RCLCPP_ERROR(get_node()->get_logger(), "Failed to read custom input: %s", ex.what());
}
return {};
}
auto cast_ptr = std::dynamic_pointer_cast<T>(state);
if (cast_ptr != nullptr) {
return *cast_ptr;
} else {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(),
get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str());
auto message =
**std::get<realtime_tools::RealtimeBuffer<std::shared_ptr<modulo_core::EncodedState>>>(inputs_.at(name).buffer)
.readFromNonRT();
std::shared_ptr<state_representation::State> state;
try {
auto message_pair = input_message_pairs_.at(name);
message_pair->read<modulo_core::EncodedState, state_representation::State>(message);
state = message_pair->get_message_pair<modulo_core::EncodedState, state_representation::State>()->get_data();
} catch (const std::exception& ex) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what());
return {};
}
if (state->is_empty()) {
return {};
}
auto cast_ptr = std::dynamic_pointer_cast<T>(state);
if (cast_ptr != nullptr) {
return *cast_ptr;
} else {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(),
get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str());
}
return {};
}
return {};
}

template<>
Expand Down Expand Up @@ -689,44 +751,71 @@ inline std::optional<std::string> BaseControllerInterface::read_input<std::strin

template<typename T>
inline void BaseControllerInterface::write_output(const std::string& name, const T& data) {
if (data.is_empty()) {
RCLCPP_DEBUG_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Skipping publication of output '%s' due to emptiness of state", name.c_str());
return;
}
if (outputs_.find(name) == outputs_.end()) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000, "Could not find output '%s'", name.c_str());
return;
}
EncodedStatePublishers publishers;
try {
publishers = std::get<EncodedStatePublishers>(outputs_.at(name));
} catch (const std::bad_variant_access&) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not retrieve publisher for output '%s': Invalid output type", name.c_str());
return;
}
if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')",
state_representation::get_state_type_name(output_type).c_str(),
state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str());
return;
}
auto rt_pub = std::get<2>(publishers);
if (rt_pub && rt_pub->trylock()) {

if constexpr (modulo_core::concepts::CustomT<T>) {
CustomPublishers publishers;
try {
modulo_core::translators::write_message<T>(rt_pub->msg_, data, get_node()->get_clock()->now());
} catch (const modulo_core::exceptions::MessageTranslationException& ex) {
publishers = std::get<CustomPublishers>(outputs_.at(name));
} catch (const std::bad_variant_access&) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not retrieve publisher for output '%s': Invalid output type", name.c_str());
return;
}

std::shared_ptr<realtime_tools::RealtimePublisher<T>> rt_pub;
try {
rt_pub = std::any_cast<std::shared_ptr<realtime_tools::RealtimePublisher<T>>>(publishers.second);
} catch (const std::bad_any_cast& ex) {
RCLCPP_ERROR_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(),
ex.what());
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Skipping publication of output '%s' due to wrong data type: %s", name.c_str(), ex.what());
return;
}
if (rt_pub && rt_pub->trylock()) {
rt_pub->msg_ = data;
rt_pub->unlockAndPublish();
}
} else {
if (data.is_empty()) {
RCLCPP_DEBUG_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Skipping publication of output '%s' due to emptiness of state", name.c_str());
return;
}
EncodedStatePublishers publishers;
try {
publishers = std::get<EncodedStatePublishers>(outputs_.at(name));
} catch (const std::bad_variant_access&) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Could not retrieve publisher for output '%s': Invalid output type", name.c_str());
return;
}
if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) {
RCLCPP_WARN_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000,
"Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')",
state_representation::get_state_type_name(output_type).c_str(),
state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str());
return;
}
auto rt_pub = std::get<2>(publishers);
if (rt_pub && rt_pub->trylock()) {
try {
modulo_core::translators::write_message<T>(rt_pub->msg_, data, get_node()->get_clock()->now());
} catch (const modulo_core::exceptions::MessageTranslationException& ex) {
RCLCPP_ERROR_THROTTLE(
get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(),
ex.what());
}
rt_pub->unlockAndPublish();
}
rt_pub->unlockAndPublish();
}
}

Expand Down
14 changes: 11 additions & 3 deletions source/modulo_controllers/src/BaseControllerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ void BaseControllerInterface::create_input(
const ControllerInput& input, const std::string& name, const std::string& topic_name) {
auto parsed_name = validate_and_declare_signal(name, "input", topic_name);
if (!parsed_name.empty()) {
inputs_.insert_or_assign(name, input);
inputs_.insert_or_assign(parsed_name, input);
}
}

Expand All @@ -357,6 +357,9 @@ void BaseControllerInterface::add_inputs() {
},
[&](const realtime_tools::RealtimeBuffer<std::shared_ptr<std_msgs::msg::String>>&) {
subscriptions_.push_back(create_subscription<std_msgs::msg::String>(name, topic));
},
[&](const std::any&) {
custom_input_configuration_callables_.at(name)(name, topic);
}},
input.buffer);
} catch (const std::exception& ex) {
Expand All @@ -369,7 +372,7 @@ void BaseControllerInterface::create_output(
const PublisherVariant& publishers, const std::string& name, const std::string& topic_name) {
auto parsed_name = validate_and_declare_signal(name, "output", topic_name);
if (!parsed_name.empty()) {
outputs_.insert_or_assign(name, publishers);
outputs_.insert_or_assign(parsed_name, publishers);
}
}

Expand Down Expand Up @@ -403,10 +406,15 @@ void BaseControllerInterface::add_outputs() {
[&](StringPublishers& pub) {
pub.first = get_node()->create_publisher<std_msgs::msg::String>(topic, qos_);
pub.second = std::make_shared<realtime_tools::RealtimePublisher<std_msgs::msg::String>>(pub.first);
},
[&](CustomPublishers& pub) {
custom_output_configuration_callables_.at(name)(pub, name);
}},
publishers);
} catch (const std::bad_any_cast& ex) {
RCLCPP_ERROR(get_node()->get_logger(), "Failed to add custom output '%s': %s", name.c_str(), ex.what());
} catch (const std::exception& ex) {
RCLCPP_ERROR(get_node()->get_logger(), "Failed to add input '%s': %s", name.c_str(), ex.what());
RCLCPP_ERROR(get_node()->get_logger(), "Failed to add output '%s': %s", name.c_str(), ex.what());
}
}
}
Expand Down
Loading

0 comments on commit aa22a8c

Please sign in to comment.