From 10464e56d48c0f02d41202de9f0af0725da9a8ac Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 5 Oct 2023 22:42:29 -0700 Subject: [PATCH] Fix the concurrency bug inside UnorderedReceiver --- src/helpers/buffers/ordering_sender.rs | 27 +++++++++++++++----------- src/helpers/gateway/receive.rs | 18 +++++++++-------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/helpers/buffers/ordering_sender.rs b/src/helpers/buffers/ordering_sender.rs index 21a4a4657d..cf32abcb27 100644 --- a/src/helpers/buffers/ordering_sender.rs +++ b/src/helpers/buffers/ordering_sender.rs @@ -2,12 +2,12 @@ use std::{ borrow::Borrow, cmp::Ordering, collections::VecDeque, + fmt::{Debug, Formatter}, mem::drop, num::NonZeroUsize, pin::Pin, task::{Context, Poll}, }; -use std::fmt::{Debug, Formatter}; use futures::{task::Waker, Future, Stream}; use generic_array::GenericArray; @@ -152,8 +152,12 @@ impl std::fmt::Display for WakerRejected { impl Debug for WakerRejected { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Adding waker is rejected because the expected position {} is behind actual {}. \ - Refresh your view and try again.", self.0, self.1) + write!( + f, + "Adding waker is rejected because the expected position {} is behind actual {}. \ + Refresh your view and try again.", + self.0, self.1 + ) } } @@ -169,7 +173,7 @@ impl WaitingShard { // this means this thread is out of sync and there was an update to channel's current // position. Accepting a waker could mean it will never be awakened. Rejecting this operation // will let the current thread to read the position again. - Err(WakerRejected::new(current, self.woken_at))? + Err(WakerRejected::new(current, self.woken_at))?; } // Each new addition will tend to have a larger index, so search backwards and @@ -336,7 +340,7 @@ impl OrderingSender { let curr = self.next.fetch_add(1, AcqRel); debug_assert_eq!(i, curr, "we just checked this"); } - break res + break res; } Ordering::Less => { // This is the hot path. Wait our turn. If our view of the world is obsolete @@ -354,7 +358,7 @@ impl OrderingSender { // be rejected because writer has moved the waiting shard position ahead and it won't match // the value of `self.next` read by the waiting thread. if let Ok(()) = self.waiting.add(curr, i, cx.waker()) { - break Poll::Pending + break Poll::Pending; } } } @@ -384,13 +388,15 @@ impl OrderingSender { /// The stream interface requires a mutable reference to the stream itself. /// That's not possible here as we create a ton of immutable references to this. /// This wrapper takes a trivial reference so that we can implement `Stream`. - #[cfg(test)] + #[cfg(all(test, any(unit_test, feature = "shuttle")))] fn as_stream(&self) -> OrderedStream<&Self> { OrderedStream { sender: self } } - #[cfg(test)] - pub(crate) fn as_rc_stream(self: crate::sync::Arc) -> OrderedStream> { + #[cfg(all(test, unit_test))] + pub(crate) fn as_rc_stream( + self: crate::sync::Arc, + ) -> OrderedStream> { OrderedStream { sender: self } } } @@ -481,7 +487,6 @@ mod test { sync::Arc, test_executor::run, }; - use crate::test_fixture::logging; fn sender() -> Arc { Arc::new(OrderingSender::new( @@ -653,8 +658,8 @@ mod test { /// This test is supposed to eventually hang if there is a concurrency bug inside `OrderingSender`. #[test] fn parallel_send() { - logging::setup(); const PARALLELISM: usize = 100; + run(|| async { let sender = Arc::new(OrderingSender::new( NonZeroUsize::new(PARALLELISM * ::Size::USIZE).unwrap(), diff --git a/src/helpers/gateway/receive.rs b/src/helpers/gateway/receive.rs index 4f50e7b489..282ff68e27 100644 --- a/src/helpers/gateway/receive.rs +++ b/src/helpers/gateway/receive.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use dashmap::DashMap; +use dashmap::{mapref::entry::Entry, DashMap}; use futures::Stream; use crate::{ @@ -65,13 +65,15 @@ impl Default for GatewayReceivers { impl GatewayReceivers { pub fn get_or_create UR>(&self, channel_id: &ChannelId, ctr: F) -> UR { - let receivers = &self.inner; - if let Some(recv) = receivers.get(channel_id) { - recv.clone() - } else { - let stream = ctr(); - receivers.insert(channel_id.clone(), stream.clone()); - stream + // TODO: raw entry API if it becomes available to avoid cloning the key + match self.inner.entry(channel_id.clone()) { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + let stream = ctr(); + entry.insert(stream.clone()); + + stream + } } } }