Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions in rmw_wait and map queries to clients #153

Merged
merged 7 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions rmw_zenoh_cpp/src/detail/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ void EventsManager::add_new_event(
///=============================================================================
void EventsManager::attach_event_condition(
rmw_zenoh_event_type_t event_id,
std::mutex * condition_mutex,
std::condition_variable * condition_variable)
{
if (event_id > ZENOH_EVENT_ID_MAX) {
Expand All @@ -194,7 +195,8 @@ void EventsManager::attach_event_condition(
return;
}

std::lock_guard<std::mutex> lock(event_condition_mutex_);
std::lock_guard<std::mutex> lock(update_event_condition_mutex_);
event_condition_mutexes_[event_id] = condition_mutex;
event_conditions_[event_id] = condition_variable;
}

Expand All @@ -209,7 +211,8 @@ void EventsManager::detach_event_condition(rmw_zenoh_event_type_t event_id)
return;
}

std::lock_guard<std::mutex> lock(event_condition_mutex_);
std::lock_guard<std::mutex> lock(update_event_condition_mutex_);
event_condition_mutexes_[event_id] = nullptr;
event_conditions_[event_id] = nullptr;
}

Expand All @@ -224,8 +227,9 @@ void EventsManager::notify_event(rmw_zenoh_event_type_t event_id)
return;
}

std::lock_guard<std::mutex> lock(event_condition_mutex_);
std::lock_guard<std::mutex> lock(update_event_condition_mutex_);
if (event_conditions_[event_id] != nullptr) {
std::lock_guard<std::mutex> cvlk(*event_condition_mutexes_[event_id]);
event_conditions_[event_id]->notify_one();
}
}
4 changes: 3 additions & 1 deletion rmw_zenoh_cpp/src/detail/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class EventsManager
/// @param condition_variable to attach.
void attach_event_condition(
rmw_zenoh_event_type_t event_id,
std::mutex * condition_mutex,
std::condition_variable * condition_variable);

/// @brief Detach the condition variable provided by rmw_wait.
Expand All @@ -154,7 +155,8 @@ class EventsManager
/// Mutex to lock when read/writing members.
mutable std::mutex event_mutex_;
/// Mutex to lock for event_condition.
mutable std::mutex event_condition_mutex_;
mutable std::mutex update_event_condition_mutex_;
std::mutex * event_condition_mutexes_[ZENOH_EVENT_ID_MAX + 1]{nullptr};
/// Condition variable to attach for event notifications.
std::condition_variable * event_conditions_[ZENOH_EVENT_ID_MAX + 1]{nullptr};
/// User callback that can be set via data_callback_mgr.set_callback().
Expand Down
7 changes: 6 additions & 1 deletion rmw_zenoh_cpp/src/detail/guard_condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,26 @@ void GuardCondition::trigger()
has_triggered_ = true;

if (condition_variable_ != nullptr) {
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_variable_->notify_one();
}
}

///==============================================================================
void GuardCondition::attach_condition(std::condition_variable * condition_variable)
void GuardCondition::attach_condition(
std::mutex * condition_mutex,
std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(internal_mutex_);
condition_mutex_ = condition_mutex;
condition_variable_ = condition_variable;
}

///==============================================================================
void GuardCondition::detach_condition()
{
std::lock_guard<std::mutex> lock(internal_mutex_);
condition_mutex_ = nullptr;
condition_variable_ = nullptr;
}

Expand Down
5 changes: 3 additions & 2 deletions rmw_zenoh_cpp/src/detail/guard_condition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class GuardCondition final
// Sets has_triggered_ to true and calls notify_one() on condition_variable_ if set.
void trigger();

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -38,7 +38,8 @@ class GuardCondition final
private:
mutable std::mutex internal_mutex_;
std::atomic_bool has_triggered_;
std::condition_variable * condition_variable_;
std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_variable_{nullptr};
};

#endif // DETAIL__GUARD_CONDITION_HPP_
107 changes: 86 additions & 21 deletions rmw_zenoh_cpp/src/detail/rmw_data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <utility>

Expand Down Expand Up @@ -62,25 +63,32 @@ size_t rmw_publisher_data_t::get_next_sequence_number()
}

///=============================================================================
void rmw_subscription_data_t::attach_condition(std::condition_variable * condition_variable)
void rmw_subscription_data_t::attach_condition(
std::mutex * condition_mutex,
std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = condition_mutex;
condition_ = condition_variable;
}

///=============================================================================
void rmw_subscription_data_t::notify()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
if (condition_ != nullptr) {
// We also need to take the mutex for the condition_variable; see the comment
// in rmw_wait for more information
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_->notify_one();
}
}

///=============================================================================
void rmw_subscription_data_t::detach_condition()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = nullptr;
condition_ = nullptr;
}

Expand Down Expand Up @@ -149,16 +157,20 @@ bool rmw_service_data_t::query_queue_is_empty() const
}

///=============================================================================
void rmw_service_data_t::attach_condition(std::condition_variable * condition_variable)
void rmw_service_data_t::attach_condition(
std::mutex * condition_mutex,
std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = condition_mutex;
condition_ = condition_variable;
}

///=============================================================================
void rmw_service_data_t::detach_condition()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = nullptr;
condition_ = nullptr;
}

Expand All @@ -179,8 +191,11 @@ std::unique_ptr<ZenohQuery> rmw_service_data_t::pop_next_query()
///=============================================================================
void rmw_service_data_t::notify()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
if (condition_ != nullptr) {
// We also need to take the mutex for the condition_variable; see the comment
// in rmw_wait for more information
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_->notify_one();
}
}
Expand Down Expand Up @@ -208,40 +223,86 @@ void rmw_service_data_t::add_new_query(std::unique_ptr<ZenohQuery> query)
notify();
}

static size_t hash_gid(const rmw_request_id_t & request_id)
{
std::stringstream hash_str;
hash_str << std::hex;
size_t i = 0;
for (; i < (RMW_GID_STORAGE_SIZE - 1); i++) {
hash_str << static_cast<int>(request_id.writer_guid[i]);
}
return std::hash<std::string>{}(hash_str.str());
}

///=============================================================================
bool rmw_service_data_t::add_to_query_map(
int64_t sequence_number, std::unique_ptr<ZenohQuery> query)
const rmw_request_id_t & request_id, std::unique_ptr<ZenohQuery> query)
{
size_t hash = hash_gid(request_id);

std::lock_guard<std::mutex> lock(sequence_to_query_map_mutex_);
if (sequence_to_query_map_.find(sequence_number) != sequence_to_query_map_.end()) {
return false;

std::unordered_map<size_t, SequenceToQuery>::iterator it =
sequence_to_query_map_.find(hash);

if (it == sequence_to_query_map_.end()) {
SequenceToQuery stq;

sequence_to_query_map_.insert(std::make_pair(hash, std::move(stq)));

it = sequence_to_query_map_.find(hash);
} else {
// Client already in the map

if (it->second.find(request_id.sequence_number) != it->second.end()) {
return false;
}
}
sequence_to_query_map_.emplace(
std::pair(sequence_number, std::move(query)));

it->second.insert(
std::make_pair(request_id.sequence_number, std::move(query)));

return true;
}

///=============================================================================
std::unique_ptr<ZenohQuery> rmw_service_data_t::take_from_query_map(int64_t sequence_number)
std::unique_ptr<ZenohQuery> rmw_service_data_t::take_from_query_map(
const rmw_request_id_t & request_id)
{
size_t hash = hash_gid(request_id);

std::lock_guard<std::mutex> lock(sequence_to_query_map_mutex_);
auto query_it = sequence_to_query_map_.find(sequence_number);
if (query_it == sequence_to_query_map_.end()) {

std::unordered_map<size_t, SequenceToQuery>::iterator it = sequence_to_query_map_.find(hash);

if (it == sequence_to_query_map_.end()) {
return nullptr;
}

SequenceToQuery::iterator query_it = it->second.find(request_id.sequence_number);

if (query_it == it->second.end()) {
return nullptr;
}

std::unique_ptr<ZenohQuery> query = std::move(query_it->second);
sequence_to_query_map_.erase(query_it);
it->second.erase(query_it);

if (sequence_to_query_map_[hash].size() == 0) {
sequence_to_query_map_.erase(hash);
}

return query;
}

///=============================================================================
void rmw_client_data_t::notify()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
if (condition_ != nullptr) {
// We also need to take the mutex for the condition_variable; see the comment
// in rmw_wait for more information
std::lock_guard<std::mutex> cvlk(*condition_mutex_);
condition_->notify_one();
}
}
Expand Down Expand Up @@ -278,16 +339,20 @@ bool rmw_client_data_t::reply_queue_is_empty() const
}

///=============================================================================
void rmw_client_data_t::attach_condition(std::condition_variable * condition_variable)
void rmw_client_data_t::attach_condition(
std::mutex * condition_mutex,
std::condition_variable * condition_variable)
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = condition_mutex;
condition_ = condition_variable;
}

///=============================================================================
void rmw_client_data_t::detach_condition()
{
std::lock_guard<std::mutex> lock(condition_mutex_);
std::lock_guard<std::mutex> lock(update_condition_mutex_);
condition_mutex_ = nullptr;
condition_ = nullptr;
}

Expand Down
25 changes: 14 additions & 11 deletions rmw_zenoh_cpp/src/detail/rmw_data_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
Expand Down Expand Up @@ -173,7 +172,7 @@ class rmw_subscription_data_t final
MessageTypeSupport * type_support;
rmw_context_t * context;

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -192,8 +191,9 @@ class rmw_subscription_data_t final

void notify();

std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_{nullptr};
std::mutex condition_mutex_;
std::mutex update_condition_mutex_;
};


Expand Down Expand Up @@ -244,17 +244,17 @@ class rmw_service_data_t final

bool query_queue_is_empty() const;

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

std::unique_ptr<ZenohQuery> pop_next_query();

void add_new_query(std::unique_ptr<ZenohQuery> query);

bool add_to_query_map(int64_t sequence_number, std::unique_ptr<ZenohQuery> query);
bool add_to_query_map(const rmw_request_id_t & request_id, std::unique_ptr<ZenohQuery> query);

std::unique_ptr<ZenohQuery> take_from_query_map(int64_t sequence_number);
std::unique_ptr<ZenohQuery> take_from_query_map(const rmw_request_id_t & request_id);

DataCallbackManager data_callback_mgr;

Expand All @@ -265,12 +265,14 @@ class rmw_service_data_t final
std::deque<std::unique_ptr<ZenohQuery>> query_queue_;
mutable std::mutex query_queue_mutex_;

// Map to store the sequence_number -> query_id
std::unordered_map<int64_t, std::unique_ptr<ZenohQuery>> sequence_to_query_map_;
// Map to store the sequence_number (as given by the client) -> ZenohQuery
using SequenceToQuery = std::unordered_map<int64_t, std::unique_ptr<ZenohQuery>>;
std::unordered_map<size_t, SequenceToQuery> sequence_to_query_map_;
std::mutex sequence_to_query_map_mutex_;

std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_{nullptr};
std::mutex condition_mutex_;
std::mutex update_condition_mutex_;
};

///=============================================================================
Expand Down Expand Up @@ -320,7 +322,7 @@ class rmw_client_data_t final

bool reply_queue_is_empty() const;

void attach_condition(std::condition_variable * condition_variable);
void attach_condition(std::mutex * condition_mutex, std::condition_variable * condition_variable);

void detach_condition();

Expand All @@ -334,8 +336,9 @@ class rmw_client_data_t final
size_t sequence_number_{1};
std::mutex sequence_number_mutex_;

std::mutex * condition_mutex_{nullptr};
std::condition_variable * condition_{nullptr};
std::mutex condition_mutex_;
std::mutex update_condition_mutex_;

std::deque<std::unique_ptr<ZenohReply>> reply_queue_;
mutable std::mutex reply_queue_mutex_;
Expand Down
Loading
Loading