Skip to content

Commit

Permalink
Fix the concurrency bug inside UnorderedReceiver
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Oct 6, 2023
1 parent 860f251 commit 10464e5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 19 deletions.
27 changes: 16 additions & 11 deletions src/helpers/buffers/ordering_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use std::{
borrow::Borrow,
cmp::Ordering,
collections::VecDeque,
fmt::{Debug, Formatter},
mem::drop,
num::NonZeroUsize,
pin::Pin,
task::{Context, Poll},
};
use std::fmt::{Debug, Formatter};

use futures::{task::Waker, Future, Stream};
use generic_array::GenericArray;
Expand Down Expand Up @@ -152,8 +152,12 @@ impl std::fmt::Display for WakerRejected {

impl Debug for WakerRejected {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Adding waker is rejected because the expected position {} is behind actual {}. \
Refresh your view and try again.", self.0, self.1)
write!(
f,
"Adding waker is rejected because the expected position {} is behind actual {}. \
Refresh your view and try again.",
self.0, self.1
)
}
}

Expand All @@ -169,7 +173,7 @@ impl WaitingShard {
// this means this thread is out of sync and there was an update to channel's current
// position. Accepting a waker could mean it will never be awakened. Rejecting this operation
// will let the current thread to read the position again.
Err(WakerRejected::new(current, self.woken_at))?
Err(WakerRejected::new(current, self.woken_at))?;
}

// Each new addition will tend to have a larger index, so search backwards and
Expand Down Expand Up @@ -336,7 +340,7 @@ impl OrderingSender {
let curr = self.next.fetch_add(1, AcqRel);
debug_assert_eq!(i, curr, "we just checked this");
}
break res
break res;
}
Ordering::Less => {
// This is the hot path. Wait our turn. If our view of the world is obsolete
Expand All @@ -354,7 +358,7 @@ impl OrderingSender {
// be rejected because writer has moved the waiting shard position ahead and it won't match
// the value of `self.next` read by the waiting thread.
if let Ok(()) = self.waiting.add(curr, i, cx.waker()) {
break Poll::Pending
break Poll::Pending;
}
}
}
Expand Down Expand Up @@ -384,13 +388,15 @@ impl OrderingSender {
/// The stream interface requires a mutable reference to the stream itself.
/// That's not possible here as we create a ton of immutable references to this.
/// This wrapper takes a trivial reference so that we can implement `Stream`.
#[cfg(test)]
#[cfg(all(test, any(unit_test, feature = "shuttle")))]
fn as_stream(&self) -> OrderedStream<&Self> {
OrderedStream { sender: self }
}

#[cfg(test)]
pub(crate) fn as_rc_stream(self: crate::sync::Arc<Self>) -> OrderedStream<crate::sync::Arc<Self>> {
#[cfg(all(test, unit_test))]
pub(crate) fn as_rc_stream(
self: crate::sync::Arc<Self>,
) -> OrderedStream<crate::sync::Arc<Self>> {
OrderedStream { sender: self }
}
}
Expand Down Expand Up @@ -481,7 +487,6 @@ mod test {
sync::Arc,
test_executor::run,
};
use crate::test_fixture::logging;

fn sender() -> Arc<OrderingSender> {
Arc::new(OrderingSender::new(
Expand Down Expand Up @@ -653,8 +658,8 @@ mod test {
/// This test is supposed to eventually hang if there is a concurrency bug inside `OrderingSender`.
#[test]
fn parallel_send() {
logging::setup();
const PARALLELISM: usize = 100;

run(|| async {
let sender = Arc::new(OrderingSender::new(
NonZeroUsize::new(PARALLELISM * <Fp31 as Serializable>::Size::USIZE).unwrap(),
Expand Down
18 changes: 10 additions & 8 deletions src/helpers/gateway/receive.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::marker::PhantomData;

use dashmap::DashMap;
use dashmap::{mapref::entry::Entry, DashMap};
use futures::Stream;

use crate::{
Expand Down Expand Up @@ -65,13 +65,15 @@ impl<T: Transport> Default for GatewayReceivers<T> {

impl<T: Transport> GatewayReceivers<T> {
pub fn get_or_create<F: FnOnce() -> UR<T>>(&self, channel_id: &ChannelId, ctr: F) -> UR<T> {
let receivers = &self.inner;
if let Some(recv) = receivers.get(channel_id) {
recv.clone()
} else {
let stream = ctr();
receivers.insert(channel_id.clone(), stream.clone());
stream
// TODO: raw entry API if it becomes available to avoid cloning the key
match self.inner.entry(channel_id.clone()) {
Entry::Occupied(entry) => entry.get().clone(),
Entry::Vacant(entry) => {
let stream = ctr();
entry.insert(stream.clone());

stream
}
}
}
}

0 comments on commit 10464e5

Please sign in to comment.