Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix UB in ClientData stuff. #320

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion rmw_zenoh_cpp/src/detail/rmw_client_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ rmw_ret_t ClientData::send_request(
opts.value.payload = z_bytes_t{data_length, reinterpret_cast<const uint8_t *>(request_bytes)};
// TODO(Yadunund): Once we switch to zenoh-cpp with lambda closures,
// capture shared_from_this() instead of this.
num_in_flight_++;
z_owned_closure_reply_t zn_closure_reply =
z_closure(client_data_handler, client_data_drop, this);
z_get(
Expand Down Expand Up @@ -563,7 +564,7 @@ bool ClientData::shutdown_and_query_in_flight()
///=============================================================================
void ClientData::decrement_in_flight_and_conditionally_remove()
{
std::lock_guard<std::recursive_mutex> lock(mutex_);
std::unique_lock<std::recursive_mutex> lock(mutex_);
--num_in_flight_;

if (is_shutdown_ && num_in_flight_ == 0) {
Expand All @@ -575,6 +576,8 @@ void ClientData::decrement_in_flight_and_conditionally_remove()
if (node_data == nullptr) {
return;
}
// We have to unlock here since we are about to delete ourself, and thus the unlock would be UB.
lock.unlock();
node_data->delete_client_data(rmw_client_);
}
}
Expand Down
30 changes: 15 additions & 15 deletions rmw_zenoh_cpp/src/detail/rmw_node_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ NodeData::~NodeData()
///=============================================================================
std::size_t NodeData::id() const
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
return id_;
}

Expand All @@ -128,7 +128,7 @@ bool NodeData::create_pub_data(
const rosidl_message_type_support_t * type_support,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -169,7 +169,7 @@ bool NodeData::create_pub_data(
///=============================================================================
PublisherDataPtr NodeData::get_pub_data(const rmw_publisher_t * const publisher)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = pubs_.find(publisher);
if (it == pubs_.end()) {
return nullptr;
Expand All @@ -181,7 +181,7 @@ PublisherDataPtr NodeData::get_pub_data(const rmw_publisher_t * const publisher)
///=============================================================================
void NodeData::delete_pub_data(const rmw_publisher_t * const publisher)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
pubs_.erase(publisher);
}

Expand All @@ -195,7 +195,7 @@ bool NodeData::create_sub_data(
const rosidl_message_type_support_t * type_support,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -237,7 +237,7 @@ bool NodeData::create_sub_data(
///=============================================================================
SubscriptionDataPtr NodeData::get_sub_data(const rmw_subscription_t * const subscription)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = subs_.find(subscription);
if (it == subs_.end()) {
return nullptr;
Expand All @@ -249,7 +249,7 @@ SubscriptionDataPtr NodeData::get_sub_data(const rmw_subscription_t * const subs
///=============================================================================
void NodeData::delete_sub_data(const rmw_subscription_t * const subscription)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
subs_.erase(subscription);
}

Expand All @@ -262,7 +262,7 @@ bool NodeData::create_service_data(
const rosidl_service_type_support_t * type_supports,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -303,7 +303,7 @@ bool NodeData::create_service_data(
///=============================================================================
ServiceDataPtr NodeData::get_service_data(const rmw_service_t * const service)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = services_.find(service);
if (it == services_.end()) {
return nullptr;
Expand All @@ -315,7 +315,7 @@ ServiceDataPtr NodeData::get_service_data(const rmw_service_t * const service)
///=============================================================================
void NodeData::delete_service_data(const rmw_service_t * const service)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
services_.erase(service);
}

Expand All @@ -329,7 +329,7 @@ bool NodeData::create_client_data(
const rosidl_service_type_support_t * type_supports,
const rmw_qos_profile_t * qos_profile)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
if (is_shutdown_) {
RMW_ZENOH_LOG_ERROR_NAMED(
"rmw_zenoh_cpp",
Expand Down Expand Up @@ -371,7 +371,7 @@ bool NodeData::create_client_data(
///=============================================================================
ClientDataPtr NodeData::get_client_data(const rmw_client_t * const client)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto it = clients_.find(client);
if (it == clients_.end()) {
return nullptr;
Expand All @@ -383,7 +383,7 @@ ClientDataPtr NodeData::get_client_data(const rmw_client_t * const client)
///=============================================================================
void NodeData::delete_client_data(const rmw_client_t * const client)
{
std::lock_guard<std::mutex> lock_guard(mutex_);
std::lock_guard<std::recursive_mutex> lock_guard(mutex_);
auto client_it = clients_.find(client);
if (client_it == clients_.end()) {
return;
Expand All @@ -396,7 +396,7 @@ void NodeData::delete_client_data(const rmw_client_t * const client)
///=============================================================================
rmw_ret_t NodeData::shutdown()
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
rmw_ret_t ret = RMW_RET_OK;
if (is_shutdown_) {
return ret;
Expand Down Expand Up @@ -463,7 +463,7 @@ rmw_ret_t NodeData::shutdown()
// Check if the Node is shutdown.
bool NodeData::is_shutdown() const
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
return is_shutdown_;
}

Expand Down
2 changes: 1 addition & 1 deletion rmw_zenoh_cpp/src/detail/rmw_node_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class NodeData final
std::shared_ptr<liveliness::Entity> entity,
zc_owned_liveliness_token_t token);
// Internal mutex.
mutable std::mutex mutex_;
mutable std::recursive_mutex mutex_;
// The rmw_node_t associated with this NodeData.
const rmw_node_t * node_;
// The entity id of this node as generated by get_next_entity_id().
Expand Down