From 36539bd9be8178f58a85d912ee7e9cb96295f74b Mon Sep 17 00:00:00 2001 From: Chris Lalancette Date: Wed, 20 Nov 2024 21:40:42 +0000 Subject: [PATCH] Fix UB in ClientData stuff. The num_in_flight stuff was *still* wrong here. First of all, we forgot to increment num_in_flight when actually kicking off a new query. Once we did that, we had to change the lock in NodeData to a recursive one, since the call to delete_client_data from ClientData could be called recursively. And then finally we had to drop the ClientData lock before the delete_client_data, since we are about to delete ourselves and the unlock would have been UB. Signed-off-by: Chris Lalancette --- rmw_zenoh_cpp/src/detail/rmw_client_data.cpp | 5 +++- rmw_zenoh_cpp/src/detail/rmw_node_data.cpp | 30 ++++++++++---------- rmw_zenoh_cpp/src/detail/rmw_node_data.hpp | 2 +- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/rmw_zenoh_cpp/src/detail/rmw_client_data.cpp b/rmw_zenoh_cpp/src/detail/rmw_client_data.cpp index cf0fcb4e..7bfe4c16 100644 --- a/rmw_zenoh_cpp/src/detail/rmw_client_data.cpp +++ b/rmw_zenoh_cpp/src/detail/rmw_client_data.cpp @@ -471,6 +471,7 @@ rmw_ret_t ClientData::send_request( opts.value.payload = z_bytes_t{data_length, reinterpret_cast(request_bytes)}; // TODO(Yadunund): Once we switch to zenoh-cpp with lambda closures, // capture shared_from_this() instead of this. 
+ num_in_flight_++; z_owned_closure_reply_t zn_closure_reply = z_closure(client_data_handler, client_data_drop, this); z_get( @@ -563,7 +564,7 @@ bool ClientData::shutdown_and_query_in_flight() ///============================================================================= void ClientData::decrement_in_flight_and_conditionally_remove() { - std::lock_guard lock(mutex_); + std::unique_lock lock(mutex_); --num_in_flight_; if (is_shutdown_ && num_in_flight_ == 0) { @@ -575,6 +576,8 @@ void ClientData::decrement_in_flight_and_conditionally_remove() if (node_data == nullptr) { return; } + // We have to unlock here since we are about to delete ourself, and thus the unlock would be UB. + lock.unlock(); node_data->delete_client_data(rmw_client_); } } diff --git a/rmw_zenoh_cpp/src/detail/rmw_node_data.cpp b/rmw_zenoh_cpp/src/detail/rmw_node_data.cpp index bd3f3f6e..e05fb3e7 100644 --- a/rmw_zenoh_cpp/src/detail/rmw_node_data.cpp +++ b/rmw_zenoh_cpp/src/detail/rmw_node_data.cpp @@ -115,7 +115,7 @@ NodeData::~NodeData() ///============================================================================= std::size_t NodeData::id() const { - std::lock_guard<std::mutex> lock(mutex_); + std::lock_guard<std::recursive_mutex> lock(mutex_); return id_; } @@ -128,7 +128,7 @@ bool NodeData::create_pub_data( const rosidl_message_type_support_t * type_support, const rmw_qos_profile_t * qos_profile) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); if (is_shutdown_) { RMW_ZENOH_LOG_ERROR_NAMED( "rmw_zenoh_cpp", @@ -169,7 +169,7 @@ bool NodeData::create_pub_data( ///============================================================================= PublisherDataPtr NodeData::get_pub_data(const rmw_publisher_t * const publisher) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); auto it = pubs_.find(publisher); if (it == pubs_.end()) { return nullptr; @@ -181,7 +181,7 @@ PublisherDataPtr NodeData::get_pub_data(const rmw_publisher_t * const publisher) 
///============================================================================= void NodeData::delete_pub_data(const rmw_publisher_t * const publisher) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); pubs_.erase(publisher); } @@ -195,7 +195,7 @@ bool NodeData::create_sub_data( const rosidl_message_type_support_t * type_support, const rmw_qos_profile_t * qos_profile) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); if (is_shutdown_) { RMW_ZENOH_LOG_ERROR_NAMED( "rmw_zenoh_cpp", @@ -237,7 +237,7 @@ bool NodeData::create_sub_data( ///============================================================================= SubscriptionDataPtr NodeData::get_sub_data(const rmw_subscription_t * const subscription) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); auto it = subs_.find(subscription); if (it == subs_.end()) { return nullptr; @@ -249,7 +249,7 @@ SubscriptionDataPtr NodeData::get_sub_data(const rmw_subscription_t * const subs ///============================================================================= void NodeData::delete_sub_data(const rmw_subscription_t * const subscription) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); subs_.erase(subscription); } @@ -262,7 +262,7 @@ bool NodeData::create_service_data( const rosidl_service_type_support_t * type_supports, const rmw_qos_profile_t * qos_profile) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); if (is_shutdown_) { RMW_ZENOH_LOG_ERROR_NAMED( "rmw_zenoh_cpp", @@ -303,7 +303,7 @@ bool NodeData::create_service_data( ///============================================================================= ServiceDataPtr NodeData::get_service_data(const rmw_service_t * const service) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); auto it = services_.find(service); if (it == services_.end()) { return nullptr; @@ -315,7 +315,7 @@ ServiceDataPtr 
NodeData::get_service_data(const rmw_service_t * const service) ///============================================================================= void NodeData::delete_service_data(const rmw_service_t * const service) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); services_.erase(service); } @@ -329,7 +329,7 @@ bool NodeData::create_client_data( const rosidl_service_type_support_t * type_supports, const rmw_qos_profile_t * qos_profile) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); if (is_shutdown_) { RMW_ZENOH_LOG_ERROR_NAMED( "rmw_zenoh_cpp", @@ -371,7 +371,7 @@ bool NodeData::create_client_data( ///============================================================================= ClientDataPtr NodeData::get_client_data(const rmw_client_t * const client) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); auto it = clients_.find(client); if (it == clients_.end()) { return nullptr; @@ -383,7 +383,7 @@ ClientDataPtr NodeData::get_client_data(const rmw_client_t * const client) ///============================================================================= void NodeData::delete_client_data(const rmw_client_t * const client) { - std::lock_guard<std::mutex> lock_guard(mutex_); + std::lock_guard<std::recursive_mutex> lock_guard(mutex_); auto client_it = clients_.find(client); if (client_it == clients_.end()) { return; @@ -396,7 +396,7 @@ void NodeData::delete_client_data(const rmw_client_t * const client) ///============================================================================= rmw_ret_t NodeData::shutdown() { - std::lock_guard<std::mutex> lock(mutex_); + std::lock_guard<std::recursive_mutex> lock(mutex_); rmw_ret_t ret = RMW_RET_OK; if (is_shutdown_) { return ret; @@ -463,7 +463,7 @@ rmw_ret_t NodeData::shutdown() // Check if the Node is shutdown. 
bool NodeData::is_shutdown() const { - std::lock_guard<std::mutex> lock(mutex_); + std::lock_guard<std::recursive_mutex> lock(mutex_); return is_shutdown_; } diff --git a/rmw_zenoh_cpp/src/detail/rmw_node_data.hpp b/rmw_zenoh_cpp/src/detail/rmw_node_data.hpp index f85b1366..5489e4b7 100644 --- a/rmw_zenoh_cpp/src/detail/rmw_node_data.hpp +++ b/rmw_zenoh_cpp/src/detail/rmw_node_data.hpp @@ -130,7 +130,7 @@ class NodeData final std::shared_ptr entity, zc_owned_liveliness_token_t token); // Internal mutex. - mutable std::mutex mutex_; + mutable std::recursive_mutex mutex_; // The rmw_node_t associated with this NodeData. const rmw_node_t * node_; // The entity id of this node as generated by get_next_entity_id().