From 362dd561e95a9812c00f80bf1c36c8ba5ad44f80 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 8 Nov 2023 14:58:09 -0800 Subject: [PATCH 01/27] Support multithreading in `seq_join`/`parallel_join` Support is currently behind a feature flag that is not enabled by default We use userspace concurrency to drive many futures in parallel by spawning tasks inside the executor. This model is not ideal for performance because memory loads will happen across thread boundaries and NUMA cores, but already gives 50% more throughput for OPRF version and 200% to old IPA. --- ipa-core/Cargo.toml | 3 + ipa-core/src/protocol/basics/reshare.rs | 2 +- .../modulus_conversion/convert_shares.rs | 32 +- .../protocol/sort/generate_permutation_opt.rs | 4 +- .../replicated/malicious/additive_share.rs | 2 +- ipa-core/src/secret_sharing/scheme.rs | 3 +- ipa-core/src/seq_join.rs | 592 +++++++++++++----- 7 files changed, 453 insertions(+), 185 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 68867c196..7ee527ec0 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -61,6 +61,8 @@ step-trace = ["descriptive-gate"] # of unit tests use it. Compact uses memory-efficient gates and is suitable for production. descriptive-gate = [] compact-gate = ["ipa-macros/compact-gate"] +# Enable using more than one thread for protocol execution. Most of the parallelism occurs at parallel/seq_join operations +multi-threading = ["async-scoped"] # Standalone aggregation protocol. We use IPA infra for communication # but it has nothing to do with IPA. @@ -73,6 +75,7 @@ ipa-macros = { version = "*", path = "../ipa-macros" } aes = "0.8.3" async-trait = "0.1.68" +async-scoped = { version = "0.7.1", features = ["use-tokio"], optional = true } axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = [ "rustls", diff --git a/ipa-core/src/protocol/basics/reshare.rs b/ipa-core/src/protocol/basics/reshare.rs index 2e9a868e6..70b65c1c3 100644 --- a/ipa-core/src/protocol/basics/reshare.rs +++ b/ipa-core/src/protocol/basics/reshare.rs @@ -42,7 +42,7 @@ use crate::{ /// `to_helper` = (`rand_left`, `rand_right`) = (r0, r1) /// `to_helper.right` = (`rand_right`, part1 + part2) = (r0, part1 + part2) #[async_trait] -pub trait Reshare: Sized { +pub trait Reshare: Sized + 'static { async fn reshare<'fut>( &self, ctx: C, diff --git a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs index 87df09abf..dae9ae8c1 100644 --- a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs +++ b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs @@ -292,17 +292,17 @@ where /// # Panics /// If the total record count on the context is unspecified. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_bits( +pub fn convert_bits<'a, F, V, C, S, VS>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, Error>> +) -> impl Stream, Error>> + 'a where F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { @@ -313,35 +313,37 @@ where /// Note that unconverted fields are not upgraded, so they might need to be upgraded either before or /// after invoking this function. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_selected_bits( +pub fn convert_selected_bits<'inp, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'inp where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'inp, + C: UpgradedContext + 'inp, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'inp, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { convert_some_bits(ctx, binary_shares, RecordId::FIRST, bit_range) } -pub(crate) fn convert_some_bits( +pub(crate) fn convert_some_bits<'a, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, first_record: RecordId, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'a where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { diff --git a/ipa-core/src/protocol/sort/generate_permutation_opt.rs b/ipa-core/src/protocol/sort/generate_permutation_opt.rs index b82e05a80..22d2eed1b 100644 --- a/ipa-core/src/protocol/sort/generate_permutation_opt.rs +++ b/ipa-core/src/protocol/sort/generate_permutation_opt.rs @@ -298,7 +298,9 @@ mod tests { } /// Passing 32 records for Fp31 doesn't work. - #[tokio::test] + /// + /// Requires one extra thread to cancel futures running in parallel with the one that panics. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[should_panic = "prime field ipa_core::ff::prime_field::fp31::Fp31 is too small to sort 32 records"] async fn fp31_overflow() { const COUNT: usize = 32; diff --git a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs index c2fe440c4..d92a0d000 100644 --- a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs @@ -77,7 +77,7 @@ impl LinearSecretSharing for AdditiveShare< /// when the protocol is done. This should not be used directly. #[async_trait] pub trait Downgrade: Send { - type Target: Send; + type Target: Send + 'static; async fn downgrade(self) -> UnauthorizedDowngradeWrapper; } diff --git a/ipa-core/src/secret_sharing/scheme.rs b/ipa-core/src/secret_sharing/scheme.rs index 0d2131eeb..b20348539 100644 --- a/ipa-core/src/secret_sharing/scheme.rs +++ b/ipa-core/src/secret_sharing/scheme.rs @@ -7,7 +7,7 @@ use super::{SharedValue, WeakSharedValue}; use crate::ff::{AddSub, AddSubAssign, GaloisField}; /// Secret sharing scheme i.e. Replicated secret sharing -pub trait SecretSharing: Clone + Debug + Sized + Send + Sync { +pub trait SecretSharing: Clone + Debug + Sized + Send + Sync + 'static { const ZERO: Self; } @@ -21,6 +21,7 @@ pub trait Linear: + Mul + for<'r> Mul<&'r V, Output = Self> + Neg + + 'static { } diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index e3cb8c5b3..12cab4d67 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -1,5 +1,4 @@ use std::{ - collections::VecDeque, future::IntoFuture, num::NonZeroUsize, pin::Pin, @@ -39,15 +38,13 @@ pub fn assert_send<'a, O>( /// [`try_join_all`]: futures::future::try_join_all /// [`Stream`]: futures::stream::Stream /// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered -pub fn seq_join(active: NonZeroUsize, source: S) -> SequentialFutures +pub fn seq_join<'st, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'st, S, F> where - S: Stream + Send, - F: Future, + S: Stream + Send + 'st, + F: Future + Send, + O: Send + 'static, { - SequentialFutures { - source: source.fuse(), - active: VecDeque::with_capacity(active.get()), - } + SequentialFutures::new(active, source) } /// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter @@ -72,16 +69,37 @@ pub trait SeqJoin { /// [`active_work`]: Self::active_work /// [`parallel_join`]: Self::parallel_join /// [`join3`]: futures::future::join3 - fn try_join(&self, iterable: I) -> TryCollect, Vec> + fn try_join<'fut, I, F, O, E>( + &self, + iterable: I, + ) -> TryCollect, Vec> where I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, + I::IntoIter: Send + 'fut, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, { seq_try_join_all(self.active_work(), iterable) } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + #[cfg(feature = "multi-threading")] + fn parallel_join<'a, I, F, O, E>( + &self, + iterable: I, + ) -> Pin, E>> + Send + 'a>> + where + I: IntoIterator + Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static, + { + multi_thread::parallel_join(iterable) + } + + /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + #[cfg(not(feature = "multi-threading"))] fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll where I: IntoIterator, @@ -95,209 +113,348 @@ pub trait SeqJoin { fn active_work(&self) -> NonZeroUsize; } -type SeqTryJoinAll = SequentialFutures::IntoIter>, F>; +type SeqTryJoinAll<'st, I, F> = + SequentialFutures<'st, StreamIter<::IntoIter>, F>; /// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. /// This awaits all the provided futures in order, /// aborting early if any future returns `Result::Err`. -pub fn seq_try_join_all( +pub fn seq_try_join_all<'iter, I, F, O, E>( active: NonZeroUsize, source: I, -) -> TryCollect, Vec> +) -> TryCollect, Vec> where I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, + I::IntoIter: Send + 'iter, + F: Future> + Send + 'iter, + O: Send + 'static, + E: Send + 'static, { seq_join(active, iter(source)).try_collect() } -enum ActiveItem { - Pending(Pin>), - Resolved(F::Output), +impl<'fut, S, F> ExactSizeStream for SequentialFutures<'fut, S, F> +where + S: Stream + Send + ExactSizeStream, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, +{ } -impl ActiveItem { - /// Drives this item to resolved state when value is ready to be taken out. Has no effect - /// if the value is ready. - /// - /// ## Panics - /// Panics if this item is completed - fn check_ready(&mut self, cx: &mut Context<'_>) -> bool { - let ActiveItem::Pending(f) = self else { - return true; - }; - if let Poll::Ready(v) = Future::poll(Pin::as_mut(f), cx) { - *self = ActiveItem::Resolved(v); - true - } else { - false - } - } +#[cfg(feature = "multi-threading")] +pub type SequentialFutures<'fut, S, F> = multi_thread::SequentialFutures<'fut, S, F>; - /// Takes the resolved value out - /// - /// ## Panics - /// If the value is not ready yet. - #[must_use] - fn take(self) -> F::Output { - let ActiveItem::Resolved(v) = self else { - panic!("No value to take out"); - }; +#[cfg(not(feature = "multi-threading"))] +pub type SequentialFutures<'unused, S, F> = local::SequentialFutures<'unused, S, F>; + +/// Parallel and sequential join that use at most one thread. Good for unit testing and debugging, +/// to get results in predictable order with fewer things happening at the same time. +#[cfg(not(feature = "multi-threading"))] +mod local { + use std::{collections::VecDeque, marker::PhantomData}; - v + use super::*; + + enum ActiveItem { + Pending(Pin>), + Resolved(F::Output), } -} -#[pin_project] -pub struct SequentialFutures -where - S: Stream + Send, - F: IntoFuture, -{ - #[pin] - source: futures::stream::Fuse, - active: VecDeque>, -} + impl ActiveItem { + /// Drives this item to resolved state when value is ready to be taken out. Has no effect + /// if the value is ready. + /// + /// ## Panics + /// Panics if this item is completed + fn check_ready(&mut self, cx: &mut Context<'_>) -> bool { + let ActiveItem::Pending(f) = self else { + return true; + }; + if let Poll::Ready(v) = Future::poll(Pin::as_mut(f), cx) { + *self = ActiveItem::Resolved(v); + true + } else { + false + } + } -impl Stream for SequentialFutures -where - S: Stream + Send, - F: IntoFuture, -{ - type Item = F::Output; + /// Takes the resolved value out + /// + /// ## Panics + /// If the value is not ready yet. + #[must_use] + fn take(self) -> F::Output { + let ActiveItem::Resolved(v) = self else { + panic!("No value to take out"); + }; + + v + } + } - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); + #[pin_project] + pub struct SequentialFutures<'unused, S, F> + where + S: Stream + Send, + F: IntoFuture, + { + #[pin] + source: futures::stream::Fuse, + active: VecDeque>, + _marker: PhantomData &'unused ()>, + } - // Draw more values from the input, up to the capacity. - while this.active.len() < this.active.capacity() { - if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { - this.active - .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); - } else { - break; + impl SequentialFutures<'_, S, F> + where + S: Stream + Send, + F: IntoFuture, + { + pub fn new(active: NonZeroUsize, source: S) -> Self { + Self { + source: source.fuse(), + active: VecDeque::with_capacity(active.get()), + _marker: PhantomData, } } + } - if let Some(item) = this.active.front_mut() { - if item.check_ready(cx) { - let v = this.active.pop_front().map(ActiveItem::take); - Poll::Ready(v) - } else { - for f in this.active.iter_mut().skip(1) { - f.check_ready(cx); + impl Stream for SequentialFutures<'_, S, F> + where + S: Stream + Send, + F: IntoFuture, + { + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. + while this.active.len() < this.active.capacity() { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + this.active + .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); + } else { + break; + } + } + + if let Some(item) = this.active.front_mut() { + if item.check_ready(cx) { + let v = this.active.pop_front().map(ActiveItem::take); + Poll::Ready(v) + } else { + for f in this.active.iter_mut().skip(1) { + f.check_ready(cx); + } + Poll::Pending } + } else if this.source.is_done() { + Poll::Ready(None) + } else { Poll::Pending } - } else if this.source.is_done() { - Poll::Ready(None) - } else { - Poll::Pending } - } - fn size_hint(&self) -> (usize, Option) { - let in_progress = self.active.len(); - let (lower, upper) = self.source.size_hint(); - ( - lower.saturating_add(in_progress), - upper.and_then(|u| u.checked_add(in_progress)), - ) + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.active.len(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } } } -impl ExactSizeStream for SequentialFutures -where - S: Stream + Send + ExactSizeStream, - F: IntoFuture, -{ -} +/// Both joins use executor tasks to drive futures to completion. Much faster than single-threaded +/// version, so this is what we want to use in release/prod mode. +#[cfg(feature = "multi-threading")] +mod multi_thread { + use futures::future::BoxFuture; + use tracing::{Instrument, Span}; -#[cfg(all(test, unit_test))] -mod test { - use std::{ - convert::Infallible, - iter::once, - num::NonZeroUsize, - ptr::null, - sync::{Arc, Mutex}, - task::{Context, Poll, Waker}, - }; + use super::*; - use futures::{ - future::{lazy, BoxFuture}, - stream::{iter, poll_fn, poll_immediate, repeat_with}, - Future, StreamExt, - }; + #[cfg(feature = "shuttle")] + mod shuttle_spawner { + use shuttle_crate::{ + future, + future::{JoinError, JoinHandle}, + }; - use crate::seq_join::{seq_join, seq_try_join_all}; + use super::*; - async fn immediate(count: u32) { - let capacity = NonZeroUsize::new(3).unwrap(); - let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) - .collect::>() - .await; - assert_eq!((0..count).collect::>(), values); - } + /// Spawner implementation for Shuttle framework to run tests in parallel + pub(super) struct ShuttleSpawner; - #[tokio::test] - async fn within_capacity() { - immediate(2).await; - immediate(1).await; + unsafe impl async_scoped::spawner::Spawner for ShuttleSpawner + where + T: Send + 'static, + { + type FutureOutput = Result; + type SpawnHandle = JoinHandle; + + fn spawn + Send + 'static>(&self, f: F) -> Self::SpawnHandle { + future::spawn(f) + } + } + + unsafe impl async_scoped::spawner::Blocker for ShuttleSpawner { + fn block_on>(&self, f: F) -> T { + future::block_on(f) + } + } } - #[tokio::test] - async fn over_capacity() { - immediate(10).await; + #[cfg(feature = "shuttle")] + type Spawner<'fut, T> = async_scoped::Scope<'fut, T, shuttle_spawner::ShuttleSpawner>; + #[cfg(not(feature = "shuttle"))] + type Spawner<'fut, T> = TokioScope<'fut, T>; + + unsafe fn create_spawner<'fut, T: Send + 'static>() -> Spawner<'fut, T> { + #[cfg(feature = "shuttle")] + return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); + #[cfg(not(feature = "shuttle"))] + return TokioScope::create(Tokio); } - #[tokio::test] - async fn out_of_order() { - let capacity = NonZeroUsize::new(3).unwrap(); - let barrier = tokio::sync::Barrier::new(2); - let unresolved: BoxFuture<'_, u32> = Box::pin(async { - barrier.wait().await; - 0 - }); - let it = once(unresolved) - .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); - let mut seq_futures = seq_join(capacity, iter(it)); + #[pin_project] + #[must_use = "Futures do nothing, unless polled"] + pub struct SequentialFutures<'fut, S, F> + where + S: Stream + Send + 'fut, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, + { + #[pin] + spawner: Spawner<'fut, F::Output>, + #[pin] + source: futures::stream::Fuse, + capacity: usize, + } - assert_eq!( - Some(Poll::Pending), - poll_immediate(&mut seq_futures).next().await - ); - barrier.wait().await; - assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); + impl SequentialFutures<'_, S, F> + where + S: Stream + Send, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, + { + pub fn new(active: NonZeroUsize, source: S) -> Self { + SequentialFutures { + spawner: unsafe { create_spawner() }, + source: source.fuse(), + capacity: active.get(), + } + } } - #[tokio::test] - async fn join_success() { - fn f(v: T) -> impl Future> { - lazy(move |_| Ok(v)) + impl<'fut, S, F> Stream for SequentialFutures<'fut, S, F> + where + S: Stream + Send, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, + { + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. + while this.spawner.remaining() < *this.capacity { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + // Making futures cancellable is critical to avoid hangs. + // if one of them panics, unwinding causes spawner to drop and, in turn, + // it blocks the thread to await all pending futures completion. If there is + // a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is + // the behavior we want. + this.spawner + .spawn_cancellable(f.into_future().instrument(Span::current()), || { + panic!("cancelled") + }); + } else { + break; + } + } + + // Poll spawner if it has work to do. If both source and spawner are empty, we're done + if this.spawner.remaining() > 0 { + this.spawner.as_mut().poll_next(cx).map(|v| match v { + Some(Ok(v)) => Some(v), + Some(Err(_)) => panic!("task is cancelled"), + None => None, + }) + } else if this.source.is_done() { + Poll::Ready(None) + } else { + Poll::Pending + } } - let active = NonZeroUsize::new(10).unwrap(); - let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); - assert_eq!((1..5).collect::>(), res); + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.spawner.remaining(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } } - #[tokio::test] - async fn try_join_early_abort() { - const ERROR: &str = "error message"; - fn f(i: u32) -> impl Future> { - lazy(move |_| match i { - 1 => Ok(1), - 2 => Err(ERROR), - _ => panic!("should have aborted earlier"), - }) - } + /// TODO: change it to impl Future once https://github.com/rust-lang/rust/pull/115822 is + /// available in stable Rust. + pub(super) fn parallel_join<'fut, I, F, O, E>(iterable: I) -> BoxFuture<'fut, Result, E>> + where + I: IntoIterator + Send, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, + { + // TODO: implement spawner for shuttle + let mut scope = { + let iter = iterable.into_iter(); + // SAFETY: scope object does not escape this function. All futures are driven to + // completion inside it or cancelled if a panic occurs. + let mut scope = unsafe { create_spawner() }; + for element in iter { + // it is important to make those cancellable. + // TODO: elaborate why + scope.spawn_cancellable(element.instrument(Span::current()), || { + panic!("Future is cancelled.") + }); + } + scope + }; - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); + Box::pin(async move { + let mut result = Vec::with_capacity(scope.len()); + while let Some(item) = scope.next().await { + // join error is nothing we can do about + result.push(item.unwrap()?) + } + Ok(result) + }) } +} + +#[cfg(all(test, unit_test, not(feature = "multi-threading")))] +mod local_test { + use std::{ + num::NonZeroUsize, + ptr::null, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, + }; + + use futures::{ + future::lazy, + stream::{poll_fn, repeat_with}, + StreamExt, + }; + + use super::*; fn fake_waker() -> Waker { use std::task::{RawWaker, RawWakerVTable}; @@ -365,8 +522,8 @@ mod test { } /// A fully synchronous test with a synthetic stream, all the way to the end. - #[test] - fn complete_stream() { + #[tokio::test] + async fn complete_stream() { const VALUE: u32 = 20; const COUNT: usize = 7; let capacity = NonZeroUsize::new(3).unwrap(); @@ -408,3 +565,106 @@ mod test { assert!(matches!(res, Poll::Ready(None))); } } + +#[cfg(all(test, unit_test))] +mod test { + use std::{convert::Infallible, iter::once}; + + use futures::{ + future::{lazy, BoxFuture}, + stream::{iter, poll_immediate}, + Future, StreamExt, + }; + + use super::*; + + async fn immediate(count: u32) { + let capacity = NonZeroUsize::new(3).unwrap(); + let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) + .collect::>() + .await; + assert_eq!((0..count).collect::>(), values); + } + + #[tokio::test] + async fn within_capacity() { + immediate(2).await; + immediate(1).await; + } + + #[tokio::test] + async fn over_capacity() { + immediate(10).await; + } + + #[tokio::test] + async fn out_of_order() { + let capacity = NonZeroUsize::new(3).unwrap(); + let barrier = tokio::sync::Barrier::new(2); + let unresolved: BoxFuture<'_, u32> = Box::pin(async { + barrier.wait().await; + 0 + }); + let it = once(unresolved) + .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); + let mut seq_futures = seq_join(capacity, iter(it)); + + assert_eq!( + Some(Poll::Pending), + poll_immediate(&mut seq_futures).next().await + ); + barrier.wait().await; + assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); + } + + #[tokio::test] + async fn join_success() { + fn f(v: T) -> impl Future> { + lazy(move |_| Ok(v)) + } + + let active = NonZeroUsize::new(10).unwrap(); + let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); + assert_eq!((1..5).collect::>(), res); + } + + /// This test has to use multi-threaded runtime because early return causes `TryCollect` to be + /// dropped and the remaining futures to be cancelled which can only happen if there is more + /// than one thread available. + /// + /// This behavior is only applicable when `seq_try_join_all` uses more than one thread, for + /// maintenance reasons, we use it even parallelism is turned off. + #[tokio::test(flavor = "multi_thread")] + async fn try_join_early_abort() { + const ERROR: &str = "error message"; + fn f(i: u32) -> impl Future> { + lazy(move |_| match i { + 1 => Ok(1), + 2 => Err(ERROR), + _ => panic!("should have aborted earlier"), + }) + } + + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + } + + #[tokio::test(flavor = "multi_thread")] + async fn does_not_block_on_error() { + const ERROR: &str = "returning early is safe"; + use std::pin::Pin; + + fn f(i: u32) -> Pin> + Send>> { + match i { + 1 => Box::pin(lazy(move |_| Ok(1))), + 2 => Box::pin(lazy(move |_| Err(ERROR))), + _ => Box::pin(futures::future::pending()), + } + } + + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + } +} From 2a93f76393a538c37b581aff9f51afcb242baa6a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 10 Jan 2024 16:39:08 -0800 Subject: [PATCH 02/27] Add a test that demonstrates the unsafety of parallel_join --- ipa-core/src/seq_join.rs | 50 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 12cab4d67..d2f189e61 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -310,13 +310,13 @@ mod multi_thread { #[cfg(feature = "shuttle")] type Spawner<'fut, T> = async_scoped::Scope<'fut, T, shuttle_spawner::ShuttleSpawner>; #[cfg(not(feature = "shuttle"))] - type Spawner<'fut, T> = TokioScope<'fut, T>; + type Spawner<'fut, T> = async_scoped::TokioScope<'fut, T>; unsafe fn create_spawner<'fut, T: Send + 'static>() -> Spawner<'fut, T> { #[cfg(feature = "shuttle")] return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); #[cfg(not(feature = "shuttle"))] - return TokioScope::create(Tokio); + return async_scoped::TokioScope::create(); } #[pin_project] @@ -568,15 +568,16 @@ mod local_test { #[cfg(all(test, unit_test))] mod test { - use std::{convert::Infallible, iter::once}; + use std::{convert::Infallible, iter::once, sync::Arc}; use futures::{ - future::{lazy, BoxFuture}, + future::{lazy, poll_immediate as poll_immediate_fut, BoxFuture}, stream::{iter, poll_immediate}, Future, StreamExt, }; use super::*; + use crate::seq_join::multi_thread::parallel_join; async fn immediate(count: u32) { let capacity = NonZeroUsize::new(3).unwrap(); @@ -667,4 +668,45 @@ mod test { let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); assert_eq!(err, ERROR); } + + /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause + /// use-after-free safety error. + #[tokio::test(flavor = "multi_thread")] + #[ignore] // sanitizers will flag this test + async fn parallel_join_forget_is_not_safe() { + const N: usize = 24; + let borrow_from_me = vec![1, 2, 3]; + let barrier1 = Arc::new(tokio::sync::Barrier::new(N + 1)); + let barrier2 = Arc::new(tokio::sync::Barrier::new(N + 1)); + + let iterable = (0..N) + .map(|i| { + let borrowed = &borrow_from_me; + let b1 = barrier1.clone(); + let b2 = barrier2.clone(); + async move { + b1.wait().await; + for _ in 0..100 { + if borrowed != &vec![1, 2, 3] { + panic!("corruption inside task {i}: {borrowed:?} != [1, 2, 3]") + } + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + b2.wait().await; + Ok::<(), ()>(()) + } + }) + .collect::>(); + + let mut f = parallel_join(iterable); + poll_immediate_fut(&mut f).await; + barrier1.wait().await; + + // forgetting f does not mean that futures spawned by `parallel_join` will be cancelled. + std::mem::forget(f); + + // Async executor will still be polling futures that borrow this vector and this will cause use-after-free. + drop(borrow_from_me); + barrier2.wait().await; + } } From e1bbdc9c7154cdb4c22a9a5a375df3a83cd1b7f5 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 10 Jan 2024 17:28:16 -0800 Subject: [PATCH 03/27] Fix compile errors --- ipa-core/src/seq_join.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 045056adf..3d80c248c 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -568,10 +568,10 @@ mod local_test { #[cfg(all(test, unit_test))] mod test { - use std::{convert::Infallible, iter::once, sync::Arc}; + use std::{convert::Infallible, iter::once}; use futures::{ - future::{lazy, poll_immediate as poll_immediate_fut, BoxFuture}, + future::{lazy, BoxFuture}, stream::{iter, poll_immediate}, Future, StreamExt, }; @@ -674,7 +674,10 @@ mod test { #[cfg(feature = "multi-threading")] #[ignore] // sanitizers will flag this test async fn parallel_join_forget_is_not_safe() { - use crate::seq_join::multi_thread::parallel_join; + use futures::future::poll_immediate; + + use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; + const N: usize = 24; let borrow_from_me = vec![1, 2, 3]; let barrier1 = Arc::new(tokio::sync::Barrier::new(N + 1)); @@ -700,7 +703,7 @@ mod test { .collect::>(); let mut f = parallel_join(iterable); - poll_immediate_fut(&mut f).await; + poll_immediate(&mut f).await; barrier1.wait().await; // forgetting f does not mean that futures spawned by `parallel_join` will be cancelled. From 69655b79944e06bf75a182f893a7c6d8e9d6afa2 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 10 Jan 2024 17:46:48 -0800 Subject: [PATCH 04/27] Remove the false safety claim from parallel_join --- ipa-core/src/seq_join.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 3d80c248c..a96d0694a 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -415,8 +415,6 @@ mod multi_thread { // TODO: implement spawner for shuttle let mut scope = { let iter = iterable.into_iter(); - // SAFETY: scope object does not escape this function. All futures are driven to - // completion inside it or cancelled if a panic occurs. let mut scope = unsafe { create_spawner() }; for element in iter { // it is important to make those cancellable. From a39903f6c19f2eee23921d277a9df3afdb78c52e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 11 Jan 2024 10:38:53 -0800 Subject: [PATCH 05/27] Make `parallel_join_forget_is_not_safe` safe for sanitizers --- ipa-core/src/seq_join.rs | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index a96d0694a..346731301 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -668,33 +668,37 @@ mod test { /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause /// use-after-free safety error. + /// + /// TODO: Run tests with multi-threading runtimes in CI #[tokio::test(flavor = "multi_thread")] #[cfg(feature = "multi-threading")] - #[ignore] // sanitizers will flag this test async fn parallel_join_forget_is_not_safe() { use futures::future::poll_immediate; use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; const N: usize = 24; - let borrow_from_me = vec![1, 2, 3]; - let barrier1 = Arc::new(tokio::sync::Barrier::new(N + 1)); - let barrier2 = Arc::new(tokio::sync::Barrier::new(N + 1)); + let borrow_from_me = Arc::new(vec![1, 2, 3]); + let start = Arc::new(tokio::sync::Barrier::new(N + 1)); + // counts how many tasks have accessed `borrow_from_me` after it was destroyed. + // this test expects all tasks to access `borrow_from_me` at least once. + let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1)); let iterable = (0..N) - .map(|i| { - let borrowed = &borrow_from_me; - let b1 = barrier1.clone(); - let b2 = barrier2.clone(); + .map(|_| { + let borrowed = Arc::downgrade(&borrow_from_me); + let start = start.clone(); + let bad_access = bad_accesses.clone(); async move { - b1.wait().await; + start.wait().await; + // at this point, the parent future is forgotten and borrowed should point to nothing for _ in 0..100 { - if borrowed != &vec![1, 2, 3] { - panic!("corruption inside task {i}: {borrowed:?} != [1, 2, 3]") + if borrowed.upgrade().is_none() { + bad_access.wait().await; + break; } - tokio::time::sleep(std::time::Duration::from_millis(10)).await; + tokio::task::yield_now().await; } - b2.wait().await; Ok::<(), ()>(()) } }) @@ -702,13 +706,15 @@ mod test { let mut f = parallel_join(iterable); poll_immediate(&mut f).await; - barrier1.wait().await; + start.wait().await; // forgetting f does not mean that futures spawned by `parallel_join` will be cancelled. std::mem::forget(f); - // Async executor will still be polling futures that borrow this vector and this will cause use-after-free. + // Async executor will still be polling futures and they will try to follow this pointer. drop(borrow_from_me); - barrier2.wait().await; + + // this test should terminate because all tasks should access `borrow_from_me` at least once. + bad_accesses.wait().await; } } From 9c22e495988dcc36f784d67a97a366d45fb65e45 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 18 Jan 2024 13:45:27 -0800 Subject: [PATCH 06/27] Import async_scoped from git Version that we need hasn't been published yet. I want to unblock our efforts to deploy multi-threading --- ipa-core/Cargo.toml | 3 ++- ipa-core/src/seq_join.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 54437fd7d..55eb22847 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -75,7 +75,8 @@ ipa-macros = { version = "*", path = "../ipa-macros" } aes = "0.8.3" async-trait = "0.1.68" -async-scoped = { version = "0.7.1", features = ["use-tokio"], optional = true } +# TODO: migrate to crates.io once 0.9 is released: https://github.com/rmanoka/async-scoped/issues/27 +async-scoped = { git = "https://github.com/rmanoka/async-scoped.git", features = ["use-tokio"], optional = true } axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = [ "rustls", diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 346731301..5849ab0d9 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -316,7 +316,7 @@ mod multi_thread { #[cfg(feature = "shuttle")] return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); #[cfg(not(feature = "shuttle"))] - return async_scoped::TokioScope::create(); + return async_scoped::TokioScope::create(async_scoped::spawner::use_tokio::Tokio); } #[pin_project] From 98c59d2a1a1021bbd7d54da66ccbbb22923c7469 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 18 Jan 2024 15:03:04 -0800 Subject: [PATCH 07/27] Update CI to run multi-threading tests --- .github/workflows/check.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index aba7c292a..ab05bdc1e 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -63,9 +63,12 @@ jobs: if: ${{ success() || failure() }} run: cargo build --tests - - name: Run Tests + - name: Run tests run: cargo test + - name: Run tests with multithreading feature enabled + run: cargo test --feature "multi-threading" + - name: Run Web Tests run: cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" @@ -96,10 +99,10 @@ jobs: run: cargo build --release - name: Build concurrency tests - run: cargo build --release --features shuttle + run: cargo build --release --features "shuttle multi-threading" - name: Run concurrency tests - run: cargo test --release --features shuttle + run: cargo test --release --features "shuttle multi-threading" extra: name: Additional Builds and Concurrency Tests @@ -148,6 +151,7 @@ jobs: fail-fast: false matrix: sanitizer: [address, leak] + features: ['', 'multi-threading'] env: TARGET: x86_64-unknown-linux-gnu steps: @@ -156,7 +160,7 @@ jobs: - name: Add Rust sources run: rustup component add rust-src - name: Run tests with sanitizer - run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" + run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate ${{ matrix.features }}" coverage: name: Measure coverage From 61d8560ca715467944442dcb8067cb3b134f0d84 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 18 Jan 2024 15:09:43 -0800 Subject: [PATCH 08/27] Fix a typo in checks.yml --- .github/workflows/check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index ab05bdc1e..d7ba76bac 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -67,7 +67,7 @@ jobs: run: cargo test - name: Run tests with multithreading feature enabled - run: cargo test --feature "multi-threading" + run: cargo test --features "multi-threading" - name: Run Web Tests run: cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" From 105c76ae1deec2ca35c15a08286af4446911b3d6 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 18 Jan 2024 16:23:00 -0800 Subject: [PATCH 09/27] Prohibit methods that can leak data seq_join won't be happy with them anymore --- .clippy.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.clippy.toml b/.clippy.toml index 238592639..5c572f532 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -2,4 +2,8 @@ disallowed-methods = [ { path = "futures::future::join_all", reason = "We don't have a replacement for this method yet. Consider extending `SeqJoin` trait." }, { path = "futures::future::try_join_all", reason = "Use Context.try_join instead." }, + { path = "std::boxed::Box::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, + { path = "std::mem::forget", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, + { path = "std::mem::ManuallyDrop::new", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, + { path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, ] From a2c2f6fcf2340057eff38837076691eb4a498308 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 18 Jan 2024 17:30:50 -0800 Subject: [PATCH 10/27] Disable Shuttle tests for IPA with multi-threading enabled They fail with too many steps issue, so we likely can't make it work for the whole protocol. --- ipa-core/src/protocol/ipa/mod.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/ipa/mod.rs b/ipa-core/src/protocol/ipa/mod.rs index b59092bf9..8459d1a8a 100644 --- a/ipa-core/src/protocol/ipa/mod.rs +++ b/ipa-core/src/protocol/ipa/mod.rs @@ -466,7 +466,16 @@ where .collect::>() } -#[cfg(all(test, any(unit_test, feature = "shuttle")))] +#[cfg(all( + test, + any( + unit_test, + all( + any(feature = "shuttle", feature = "multi-threading"), + not(all(feature = "shuttle", feature = "multi-threading")) + ) + ) +))] pub mod tests { use std::num::NonZeroU32; From 12792788a916cad043203ab4cca74797158c8c65 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 18 Jan 2024 17:55:33 -0800 Subject: [PATCH 11/27] Fix a bug in ipa tests conditional compilation gate --- ipa-core/src/protocol/ipa/mod.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ipa-core/src/protocol/ipa/mod.rs b/ipa-core/src/protocol/ipa/mod.rs index 8459d1a8a..a6e541713 100644 --- a/ipa-core/src/protocol/ipa/mod.rs +++ b/ipa-core/src/protocol/ipa/mod.rs @@ -468,13 +468,7 @@ where #[cfg(all( test, - any( - unit_test, - all( - any(feature = "shuttle", feature = "multi-threading"), - not(all(feature = "shuttle", feature = "multi-threading")) - ) - ) + any(unit_test, all(feature = "shuttle", not(feature = "multi-threading"))) ))] pub mod tests { use std::num::NonZeroU32; From a19c71658e41c9318abd454cc47b392b8021a8fb Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 10:49:17 -0800 Subject: [PATCH 12/27] Upgrade tokio to allow Miri We need this fix: https://github.com/tokio-rs/tokio/pull/6179 --- ipa-core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 55eb22847..af0d0d337 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -136,7 +136,7 @@ sha2 = "0.10" shuttle-crate = { package = "shuttle", version = "0.6.1", optional = true } thiserror = "1.0" time = { version = "0.3", optional = true } -tokio = { version = "1.28", features = ["rt", "rt-multi-thread", "macros"] } +tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "macros"] } # TODO: axum-server holds onto 0.24 and we can't upgrade until they do. Or we move away from axum-server tokio-rustls = { version = "0.24", optional = true } tokio-stream = "0.1.14" From e355c39f0159fc21fe10a38ea0f236e5a8a354f6 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 11:09:18 -0800 Subject: [PATCH 13/27] Disable Tokio IO driver Miri does not support some operations that it currently does. We don't do IO inside in-memory tests, so that should be fine. Miri progress on supporting these is tracked [here](https://github.com/rust-lang/miri/issues/2057) --- ipa-core/src/lib.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 0e4f63828..7082448fd 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -118,7 +118,10 @@ pub(crate) mod test_executor { Fut: Future, { tokio::runtime::Builder::new_multi_thread() - .enable_all() + // IO driver is disabled to run our tests under Miri. If you need it, make sure you + // annotate this test with #[cfg(not(miri))] + // https://github.com/rust-lang/miri/issues/2057 + .enable_time() .build() .unwrap() .block_on(f()); From e47b15a9f99a74b738675d4958f7687db6cf5bd3 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 11:10:09 -0800 Subject: [PATCH 14/27] Make seq_join tests Miri compatible --- ipa-core/src/seq_join.rs | 104 +++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 47 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 5849ab0d9..4f368ade5 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -575,6 +575,7 @@ mod test { }; use super::*; + use crate::test_executor::run; async fn immediate(count: u32) { let capacity = NonZeroUsize::new(3).unwrap(); @@ -632,8 +633,8 @@ mod test { /// /// This behavior is only applicable when `seq_try_join_all` uses more than one thread, for /// maintenance reasons, we use it even parallelism is turned off. - #[tokio::test(flavor = "multi_thread")] - async fn try_join_early_abort() { + #[test] + fn try_join_early_abort() { const ERROR: &str = "error message"; fn f(i: u32) -> impl Future> { lazy(move |_| match i { @@ -643,13 +644,15 @@ mod test { }) } - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + }); } - #[tokio::test(flavor = "multi_thread")] - async fn does_not_block_on_error() { + #[test] + fn does_not_block_on_error() { const ERROR: &str = "returning early is safe"; use std::pin::Pin; @@ -661,60 +664,67 @@ mod test { } } - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + }); } /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause /// use-after-free safety error. - /// - /// TODO: Run tests with multi-threading runtimes in CI - #[tokio::test(flavor = "multi_thread")] + #[test] #[cfg(feature = "multi-threading")] - async fn parallel_join_forget_is_not_safe() { + fn parallel_join_forget_is_not_safe() { + use std::mem::ManuallyDrop; + use futures::future::poll_immediate; use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; - const N: usize = 24; - let borrow_from_me = Arc::new(vec![1, 2, 3]); - let start = Arc::new(tokio::sync::Barrier::new(N + 1)); - // counts how many tasks have accessed `borrow_from_me` after it was destroyed. - // this test expects all tasks to access `borrow_from_me` at least once. - let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1)); - - let iterable = (0..N) - .map(|_| { - let borrowed = Arc::downgrade(&borrow_from_me); - let start = start.clone(); - let bad_access = bad_accesses.clone(); - async move { - start.wait().await; - // at this point, the parent future is forgotten and borrowed should point to nothing - for _ in 0..100 { - if borrowed.upgrade().is_none() { - bad_access.wait().await; - break; + run(|| async { + const N: usize = 24; + let borrow_from_me = Arc::new(vec![1, 2, 3]); + let start = Arc::new(tokio::sync::Barrier::new(N + 1)); + // counts how many tasks have accessed `borrow_from_me` after it was destroyed. + // this test expects all tasks to access `borrow_from_me` at least once. + let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1)); + + let futures = (0..N) + .map(|_| { + let borrowed = Arc::downgrade(&borrow_from_me); + let start = start.clone(); + let bad_access = bad_accesses.clone(); + async move { + start.wait().await; + // at this point, the parent future is forgotten and borrowed should point to nothing + for _ in 0..100 { + if borrowed.upgrade().is_none() { + bad_access.wait().await; + break; + } + tokio::task::yield_now().await; } - tokio::task::yield_now().await; + Ok::<(), ()>(()) } - Ok::<(), ()>(()) - } - }) - .collect::>(); + }) + .collect::>(); + + let mut f = parallel_join(futures); + poll_immediate(&mut f).await; + start.wait().await; - let mut f = parallel_join(iterable); - poll_immediate(&mut f).await; - start.wait().await; + // forgetting f does not mean that futures spawned by `parallel_join` will be cancelled. + let guard = ManuallyDrop::new(f); - // forgetting f does not mean that futures spawned by `parallel_join` will be cancelled. - std::mem::forget(f); + // Async executor will still be polling futures and they will try to follow this pointer. + drop(borrow_from_me); - // Async executor will still be polling futures and they will try to follow this pointer. - drop(borrow_from_me); + // this test should terminate because all tasks should access `borrow_from_me` at least once. + bad_accesses.wait().await; - // this test should terminate because all tasks should access `borrow_from_me` at least once. - bad_accesses.wait().await; + // do not leak memory + let _ = ManuallyDrop::into_inner(guard); + }) } } From 2fcd3ddfc4b70387ddca0997b37691da5130b262 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 11:12:59 -0800 Subject: [PATCH 15/27] Run Miri in CI --- .github/workflows/check.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d7ba76bac..561856f15 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -158,9 +158,11 @@ jobs: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@nightly - name: Add Rust sources - run: rustup component add rust-src + run: rustup component add rust-src miri - name: Run tests with sanitizer run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate ${{ matrix.features }}" + - name: Run seq_join tests with Miri + run: cargo test --target $TARGET --lib seq_join --features "multi-threading" coverage: name: Measure coverage From e95bc42384ad3ad8df2f680b4ece55ee7c18cd0c Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 11:32:34 -0800 Subject: [PATCH 16/27] Fix Miri action --- .github/workflows/check.yml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 561856f15..a9d0d6c55 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -161,8 +161,16 @@ jobs: run: rustup component add rust-src miri - name: Run tests with sanitizer run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate ${{ matrix.features }}" - - name: Run seq_join tests with Miri - run: cargo test --target $TARGET --lib seq_join --features "multi-threading" + + miri: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@nightly + - name: Add Miri + run: rustup component add miri + - name: Run seq_join tests with Miri + run: cargo miri test --target $TARGET --lib seq_join --features "multi-threading" coverage: name: Measure coverage From 83334bceda1bb2847421f71589eeff78aab04315 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 11:34:42 -0800 Subject: [PATCH 17/27] More fixes The beating will continue until morale improves --- .github/workflows/check.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index a9d0d6c55..1d01290d8 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -164,12 +164,16 @@ jobs: miri: runs-on: ubuntu-latest + env: + TARGET: x86_64-unknown-linux-gnu steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@nightly - name: Add Miri run: rustup component add miri - - name: Run seq_join tests with Miri + - name: Setup Miri + run: cargo miri setup + - name: Run seq_join tests run: cargo miri test --target $TARGET --lib seq_join --features "multi-threading" coverage: From 7e964fc556dc9a67f75415850f09b481a7174f5e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 11:47:58 -0800 Subject: [PATCH 18/27] More seq_join tests compatible with Miri --- ipa-core/src/seq_join.rs | 64 ++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 4f368ade5..f98ad2d36 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -585,46 +585,54 @@ mod test { assert_eq!((0..count).collect::>(), values); } - #[tokio::test] - async fn within_capacity() { - immediate(2).await; - immediate(1).await; + #[test] + fn within_capacity() { + run(|| async { + immediate(2).await; + immediate(1).await; + }); } - #[tokio::test] - async fn over_capacity() { - immediate(10).await; + #[test] + fn over_capacity() { + run(|| async { + immediate(10).await; + }); } - #[tokio::test] - async fn out_of_order() { - let capacity = NonZeroUsize::new(3).unwrap(); - let barrier = tokio::sync::Barrier::new(2); - let unresolved: BoxFuture<'_, u32> = Box::pin(async { + #[test] + fn out_of_order() { + run(|| async { + let capacity = NonZeroUsize::new(3).unwrap(); + let barrier = tokio::sync::Barrier::new(2); + let unresolved: BoxFuture<'_, u32> = Box::pin(async { + barrier.wait().await; + 0 + }); + let it = once(unresolved) + .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); + let mut seq_futures = seq_join(capacity, iter(it)); + + assert_eq!( + Some(Poll::Pending), + poll_immediate(&mut seq_futures).next().await + ); barrier.wait().await; - 0 + assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); }); - let it = once(unresolved) - .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); - let mut seq_futures = seq_join(capacity, iter(it)); - - assert_eq!( - Some(Poll::Pending), - poll_immediate(&mut seq_futures).next().await - ); - barrier.wait().await; - assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); } - #[tokio::test] - async fn join_success() { + #[test] + fn join_success() { fn f(v: T) -> impl Future> { lazy(move |_| Ok(v)) } - let active = NonZeroUsize::new(10).unwrap(); - let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); - assert_eq!((1..5).collect::>(), res); + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); + assert_eq!((1..5).collect::>(), res); + }); } /// This test has to use multi-threaded runtime because early return causes `TryCollect` to be From 44e3e3e47ee212ce0e021384bc13d45e47386953 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 12:16:53 -0800 Subject: [PATCH 19/27] Improve code coverage --- ipa-core/src/seq_join.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index f98ad2d36..accc595b4 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -600,6 +600,21 @@ mod test { }); } + #[test] + fn size() { + run(|| async { + let mut count = 10_usize; + let capacity = NonZeroUsize::new(3).unwrap(); + let mut values = seq_join(capacity, iter((0..count).map(|i| async move { i }))); + assert_eq!((count, Some(count)), values.size_hint()); + + while values.next().await.is_some() { + count -= 1; + assert_eq!((count, Some(count)), values.size_hint()); + } + }); + } + #[test] fn out_of_order() { run(|| async { From 49e14ea38f305b8c1dfb6e733d9d5336578c302a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 12:31:53 -0800 Subject: [PATCH 20/27] Improve documentation and lift the unsafe blocks --- ipa-core/src/seq_join.rs | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index accc595b4..799b4a5b4 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -35,6 +35,10 @@ pub fn assert_send<'a, O>( /// This will fail to resolve if the progress of any future depends on a future more /// than `active` items behind it in the input sequence. /// +/// # Safety +/// If multi-threading is enabled, forgetting the resulting future will cause use-after-free error. Do not leak it or +/// prevent the future destructor from running. +/// /// [`try_join_all`]: futures::future::try_join_all /// [`Stream`]: futures::stream::Stream /// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered @@ -44,6 +48,11 @@ where F: Future + Send, O: Send + 'static, { + #[cfg(feature = "multi-threading")] + unsafe { + SequentialFutures::new(active, source) + } + #[cfg(not(feature = "multi-threading"))] SequentialFutures::new(active, source) } @@ -84,6 +93,12 @@ pub trait SeqJoin { } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + /// + /// # Safety + /// Forgetting the future returned from this function will cause use-after-free. This is a tradeoff between + /// performance and safety that allows us to use regular references instead of Arc pointers. + /// + /// Dropping the future is always safe. #[cfg(feature = "multi-threading")] fn parallel_join<'a, I, F, O, E>( &self, @@ -95,7 +110,7 @@ pub trait SeqJoin { O: Send + 'static, E: Send + 'static, { - multi_thread::parallel_join(iterable) + unsafe { multi_thread::parallel_join(iterable) } } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. @@ -340,7 +355,7 @@ mod multi_thread { F: IntoFuture, <::IntoFuture as Future>::Output: Send + 'static, { - pub fn new(active: NonZeroUsize, source: S) -> Self { + pub unsafe fn new(active: NonZeroUsize, source: S) -> Self { SequentialFutures { spawner: unsafe { create_spawner() }, source: source.fuse(), @@ -405,20 +420,22 @@ mod multi_thread { /// TODO: change it to impl Future once https://github.com/rust-lang/rust/pull/115822 is /// available in stable Rust. - pub(super) fn parallel_join<'fut, I, F, O, E>(iterable: I) -> BoxFuture<'fut, Result, E>> + pub(super) unsafe fn parallel_join<'fut, I, F, O, E>( + iterable: I, + ) -> BoxFuture<'fut, Result, E>> where I: IntoIterator + Send, F: Future> + Send + 'fut, O: Send + 'static, E: Send + 'static, { - // TODO: implement spawner for shuttle let mut scope = { let iter = iterable.into_iter(); let mut scope = unsafe { create_spawner() }; for element in iter { - // it is important to make those cancellable. - // TODO: elaborate why + // it is important to make those cancellable to avoid deadlocks if one of the spawned future panics. + // If there is a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is the behavior we want. scope.spawn_cancellable(element.instrument(Span::current()), || { panic!("Future is cancelled.") }); @@ -733,7 +750,7 @@ mod test { }) .collect::>(); - let mut f = parallel_join(futures); + let mut f = unsafe { parallel_join(futures) }; poll_immediate(&mut f).await; start.wait().await; From a40d8cd84fa2918dead963947458ea26388bd5ee Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 13:33:55 -0800 Subject: [PATCH 21/27] Fiddle with boxed futures --- ipa-core/src/seq_join.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 799b4a5b4..af9601da4 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -110,7 +110,7 @@ pub trait SeqJoin { O: Send + 'static, E: Send + 'static, { - unsafe { multi_thread::parallel_join(iterable) } + unsafe { Box::pin(multi_thread::parallel_join(iterable)) } } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. @@ -286,7 +286,6 @@ mod local { /// version, so this is what we want to use in release/prod mode. #[cfg(feature = "multi-threading")] mod multi_thread { - use futures::future::BoxFuture; use tracing::{Instrument, Span}; use super::*; @@ -418,11 +417,9 @@ mod multi_thread { } } - /// TODO: change it to impl Future once https://github.com/rust-lang/rust/pull/115822 is - /// available in stable Rust. pub(super) unsafe fn parallel_join<'fut, I, F, O, E>( iterable: I, - ) -> BoxFuture<'fut, Result, E>> + ) -> impl Future, E>> + Send + 'fut where I: IntoIterator + Send, F: Future> + Send + 'fut, @@ -443,14 +440,14 @@ mod multi_thread { scope }; - Box::pin(async move { + async move { let mut result = Vec::with_capacity(scope.len()); while let Some(item) = scope.next().await { // join error is nothing we can do about result.push(item.unwrap()?) } Ok(result) - }) + } } } @@ -750,7 +747,7 @@ mod test { }) .collect::>(); - let mut f = unsafe { parallel_join(futures) }; + let mut f = Box::pin(unsafe { parallel_join(futures) }); poll_immediate(&mut f).await; start.wait().await; From e1c9d176461f9b847b304d7b6ffbde34e90ca459 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 19 Jan 2024 13:39:59 -0800 Subject: [PATCH 22/27] Final cleanup --- .github/workflows/check.yml | 2 +- .../src/protocol/modulus_conversion/convert_shares.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 1d01290d8..6cc659bd0 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -158,7 +158,7 @@ jobs: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@nightly - name: Add Rust sources - run: rustup component add rust-src miri + run: rustup component add rust-src - name: Run tests with sanitizer run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate ${{ matrix.features }}" diff --git a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs index e3fa765fd..ac6aabf33 100644 --- a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs +++ b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs @@ -320,18 +320,18 @@ where /// Note that unconverted fields are not upgraded, so they might need to be upgraded either before or /// after invoking this function. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_selected_bits<'inp, F, V, C, S, VS, R>( +pub fn convert_selected_bits<'a, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, R), Error>> + 'inp +) -> impl Stream, R), Error>> + 'a where R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples + 'inp, - C: UpgradedContext + 'inp, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send + 'inp, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { From 2c5aa174d2417411915233bffc741bc21dc97b8c Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 22 Jan 2024 10:50:11 -0800 Subject: [PATCH 23/27] Apply suggestions from code review Thanks for thorough review! Co-authored-by: Martin Thomson --- ipa-core/src/seq_join.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index af9601da4..14f30dc0d 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -293,8 +293,7 @@ mod multi_thread { #[cfg(feature = "shuttle")] mod shuttle_spawner { use shuttle_crate::{ - future, - future::{JoinError, JoinHandle}, + future::{self, JoinError, JoinHandle}, }; use super::*; @@ -334,7 +333,7 @@ mod multi_thread { } #[pin_project] - #[must_use = "Futures do nothing, unless polled"] + #[must_use = "Futures do nothing unless polled"] pub struct SequentialFutures<'fut, S, F> where S: Stream + Send + 'fut, @@ -386,18 +385,18 @@ mod multi_thread { // the behavior we want. this.spawner .spawn_cancellable(f.into_future().instrument(Span::current()), || { - panic!("cancelled") + panic!("SequentialFutures: spawned task cancelled") }); } else { break; } } - // Poll spawner if it has work to do. If both source and spawner are empty, we're done + // Poll spawner if it has work to do. If both source and spawner are empty, we're done. if this.spawner.remaining() > 0 { this.spawner.as_mut().poll_next(cx).map(|v| match v { Some(Ok(v)) => Some(v), - Some(Err(_)) => panic!("task is cancelled"), + Some(Err(_)) => panic!("SequentialFutures: spawned task aborted"), None => None, }) } else if this.source.is_done() { @@ -434,7 +433,7 @@ mod multi_thread { // If there is a dependency between futures, pending one will never complete. // Cancellable futures will be cancelled when spawner is dropped which is the behavior we want. scope.spawn_cancellable(element.instrument(Span::current()), || { - panic!("Future is cancelled.") + panic!("parallel_join: task cancelled") }); } scope @@ -444,7 +443,7 @@ mod multi_thread { let mut result = Vec::with_capacity(scope.len()); while let Some(item) = scope.next().await { // join error is nothing we can do about - result.push(item.unwrap()?) + result.push(item.expect("parallel_join: received JoinError")?) } Ok(result) } @@ -669,7 +668,7 @@ mod test { /// than one thread available. /// /// This behavior is only applicable when `seq_try_join_all` uses more than one thread, for - /// maintenance reasons, we use it even parallelism is turned off. + /// maintenance reasons, we use it even when parallelism is turned off. #[test] fn try_join_early_abort() { const ERROR: &str = "error message"; From 66d9ed5a0f67e26d4bcbac8cc8e41397be05190e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 22 Jan 2024 19:26:44 -0800 Subject: [PATCH 24/27] More feedback --- ipa-core/src/lib.rs | 12 ++++---- ipa-core/src/seq_join.rs | 59 ++++++++++++++++++++++++++-------------- 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 7082448fd..b5fb7924a 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -112,19 +112,19 @@ pub(crate) mod test_executor { run(f); } - pub fn run(f: F) + pub fn run(f: F) -> T where F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future, + Fut: Future, { tokio::runtime::Builder::new_multi_thread() - // IO driver is disabled to run our tests under Miri. If you need it, make sure you - // annotate this test with #[cfg(not(miri))] - // https://github.com/rust-lang/miri/issues/2057 + // enable_all() is common to use to build Tokio runtime, but it enables both IO and time drivers. + // IO driver is not compatible with Miri (https://github.com/rust-lang/miri/issues/2057) which we use to + // sanitize our tests, so this runtime only enables time driver. .enable_time() .build() .unwrap() - .block_on(f()); + .block_on(f()) } } diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 14f30dc0d..146ec2e9c 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -157,11 +157,10 @@ where { } -#[cfg(feature = "multi-threading")] -pub type SequentialFutures<'fut, S, F> = multi_thread::SequentialFutures<'fut, S, F>; - #[cfg(not(feature = "multi-threading"))] -pub type SequentialFutures<'unused, S, F> = local::SequentialFutures<'unused, S, F>; +pub use local::SequentialFutures; +#[cfg(feature = "multi-threading")] +pub use multi_thread::SequentialFutures; /// Parallel and sequential join that use at most one thread. Good for unit testing and debugging, /// to get results in predictable order with fewer things happening at the same time. @@ -292,9 +291,7 @@ mod multi_thread { #[cfg(feature = "shuttle")] mod shuttle_spawner { - use shuttle_crate::{ - future::{self, JoinError, JoinHandle}, - }; + use shuttle_crate::future::{self, JoinError, JoinHandle}; use super::*; @@ -426,9 +423,8 @@ mod multi_thread { E: Send + 'static, { let mut scope = { - let iter = iterable.into_iter(); let mut scope = unsafe { create_spawner() }; - for element in iter { + for element in iterable { // it is important to make those cancellable to avoid deadlocks if one of the spawned future panics. // If there is a dependency between futures, pending one will never complete. // Cancellable futures will be cancelled when spawner is dropped which is the behavior we want. @@ -533,8 +529,8 @@ mod local_test { } /// A fully synchronous test with a synthetic stream, all the way to the end. - #[tokio::test] - async fn complete_stream() { + #[test] + fn complete_stream() { const VALUE: u32 = 20; const COUNT: usize = 7; let capacity = NonZeroUsize::new(3).unwrap(); @@ -708,18 +704,19 @@ mod test { } /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause - /// use-after-free safety error. + /// use-after-free safety error. It spawns a few tasks that constantly try to access the `borrow_from_me` weak + /// reference while the main thread drops the owning reference. By proving that futures are able to see the weak + /// pointer unset, this test shows that same can happen for regular references and cause use-after-free. #[test] #[cfg(feature = "multi-threading")] fn parallel_join_forget_is_not_safe() { - use std::mem::ManuallyDrop; - use futures::future::poll_immediate; use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; run(|| async { const N: usize = 24; + let borrowed_vec = Box::new([1, 2, 3]); let borrow_from_me = Arc::new(vec![1, 2, 3]); let start = Arc::new(tokio::sync::Barrier::new(N + 1)); // counts how many tasks have accessed `borrow_from_me` after it was destroyed. @@ -729,19 +726,27 @@ mod test { let futures = (0..N) .map(|_| { let borrowed = Arc::downgrade(&borrow_from_me); + let regular_ref = &borrowed_vec; let start = start.clone(); let bad_access = bad_accesses.clone(); async move { start.wait().await; - // at this point, the parent future is forgotten and borrowed should point to nothing for _ in 0..100 { if borrowed.upgrade().is_none() { bad_access.wait().await; + // switch to `true` if you want to see the real corruption. + #[allow(unreachable_code)] + if false { + // this is a place where we can see the use-after-free. + // we avoid executing this block to appease sanitizers, but compiler happily + // allows us to follow this reference. + println!("{:?}", regular_ref); + } break; } tokio::task::yield_now().await; } - Ok::<(), ()>(()) + Ok::<_, ()>(()) } }) .collect::>(); @@ -750,17 +755,29 @@ mod test { poll_immediate(&mut f).await; start.wait().await; - // forgetting f does not mean that futures spawned by `parallel_join` will be cancelled. - let guard = ManuallyDrop::new(f); + // the type of `f` above captures the lifetime for borrowed_vec. Leaking `f` allows `borrowed_vec` to be + // dropped, but that drop prohibits any subsequent manipulations with `f` pointer, irrespective of whether + // `f` is `&mut _` or `*mut _` (value already borrowed error). + // I am not sure I fully understand what is going on here (why borrowck allows me to leak the value, but + // then I can't drop it even if it is a raw pointer), but removing the lifetime from `f` type allows + // the test to pass. + // + // This is only required to do the proper cleanup and avoid memory leaks. Replacing this line with + // `mem::forget(f)` will lead to the same test outcome, but Miri will complain about memory leaks. + let f: _ = unsafe { + std::mem::transmute::<_, Pin, ()>> + Send>>>( + Box::pin(f) as Pin, ()>>>>, + ) + }; // Async executor will still be polling futures and they will try to follow this pointer. drop(borrow_from_me); + drop(borrowed_vec); // this test should terminate because all tasks should access `borrow_from_me` at least once. bad_accesses.wait().await; - // do not leak memory - let _ = ManuallyDrop::into_inner(guard); - }) + drop(f); + }); } } From 8dc3386d6f5f1e7109f1154fbffb0798dd0ed618 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 23 Jan 2024 10:31:40 -0800 Subject: [PATCH 25/27] Minor feedback --- ipa-core/src/seq_join.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index 146ec2e9c..f1987eb8b 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -167,6 +167,7 @@ pub use multi_thread::SequentialFutures; #[cfg(not(feature = "multi-threading"))] mod local { use std::{collections::VecDeque, marker::PhantomData}; + use futures::stream::Fuse; use super::*; @@ -200,7 +201,7 @@ mod local { #[must_use] fn take(self) -> F::Output { let ActiveItem::Resolved(v) = self else { - panic!("No value to take out"); + unreachable!("take should be only called once."); }; v @@ -214,7 +215,7 @@ mod local { F: IntoFuture, { #[pin] - source: futures::stream::Fuse, + source: Fuse, active: VecDeque>, _marker: PhantomData &'unused ()>, } @@ -285,6 +286,7 @@ mod local { /// version, so this is what we want to use in release/prod mode. #[cfg(feature = "multi-threading")] mod multi_thread { + use futures::stream::Fuse; use tracing::{Instrument, Span}; use super::*; @@ -340,7 +342,7 @@ mod multi_thread { #[pin] spawner: Spawner<'fut, F::Output>, #[pin] - source: futures::stream::Fuse, + source: Fuse, capacity: usize, } From 95b263757325a9fe80d90c526950e3e38d117949 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 23 Jan 2024 14:24:19 -0800 Subject: [PATCH 26/27] Split the `seq_join` file into separate modules One for local spawn and one for multi-threading --- .pre-commit.stashsIsbN1 | 0 ipa-core/src/seq_join.rs | 785 -------------------------- ipa-core/src/seq_join/local.rs | 268 +++++++++ ipa-core/src/seq_join/mod.rs | 274 +++++++++ ipa-core/src/seq_join/multi_thread.rs | 252 +++++++++ 5 files changed, 794 insertions(+), 785 deletions(-) create mode 100644 .pre-commit.stashsIsbN1 delete mode 100644 ipa-core/src/seq_join.rs create mode 100644 ipa-core/src/seq_join/local.rs create mode 100644 ipa-core/src/seq_join/mod.rs create mode 100644 ipa-core/src/seq_join/multi_thread.rs diff --git a/.pre-commit.stashsIsbN1 b/.pre-commit.stashsIsbN1 new file mode 100644 index 000000000..e69de29bb diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs deleted file mode 100644 index f1987eb8b..000000000 --- a/ipa-core/src/seq_join.rs +++ /dev/null @@ -1,785 +0,0 @@ -use std::{ - future::IntoFuture, - num::NonZeroUsize, - pin::Pin, - task::{Context, Poll}, -}; - -use futures::{ - stream::{iter, Iter as StreamIter, TryCollect}, - Future, Stream, StreamExt, TryStreamExt, -}; -use pin_project::pin_project; - -use crate::exact::ExactSizeStream; - -/// This helper function might be necessary to convince the compiler that -/// the return value from [`seq_try_join_all`] implements `Send`. -/// Use this if you get higher-ranked lifetime errors that mention `std::marker::Send`. -/// -/// -pub fn assert_send<'a, O>( - fut: impl Future + Send + 'a, -) -> impl Future + Send + 'a { - fut -} - -/// Sequentially join futures from a stream. -/// -/// This function polls futures in strict sequence. -/// If any future blocks, up to `active - 1` futures after it will be polled so -/// that they make progress. -/// -/// # Deadlocks -/// -/// This will fail to resolve if the progress of any future depends on a future more -/// than `active` items behind it in the input sequence. -/// -/// # Safety -/// If multi-threading is enabled, forgetting the resulting future will cause use-after-free error. Do not leak it or -/// prevent the future destructor from running. -/// -/// [`try_join_all`]: futures::future::try_join_all -/// [`Stream`]: futures::stream::Stream -/// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered -pub fn seq_join<'st, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'st, S, F> -where - S: Stream + Send + 'st, - F: Future + Send, - O: Send + 'static, -{ - #[cfg(feature = "multi-threading")] - unsafe { - SequentialFutures::new(active, source) - } - #[cfg(not(feature = "multi-threading"))] - SequentialFutures::new(active, source) -} - -/// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter -/// from the provided context so that the value can be made consistent. -pub trait SeqJoin { - /// Perform a sequential join of the futures from the provided iterable. - /// This uses [`seq_join`], with the current state of the associated object - /// being used to determine the number of active items to track (see [`active_work`]). - /// - /// A rough rule of thumb for how to decide between this and [`parallel_join`] is - /// that this should be used whenever you are iterating over different records. - /// [`parallel_join`] is better suited to smaller batches, such as iterating over - /// the bits of a value for a single record. - /// - /// Note that the join functions from the [`futures`] crate, such as [`join3`], - /// are also parallel and can be used where you have a small, fixed number of tasks. - /// - /// Be especially careful if you use the random bits generator with this. - /// The random bits generator can produce values out of sequence. - /// You might need to use [`parallel_join`] for that. - /// - /// [`active_work`]: Self::active_work - /// [`parallel_join`]: Self::parallel_join - /// [`join3`]: futures::future::join3 - fn try_join<'fut, I, F, O, E>( - &self, - iterable: I, - ) -> TryCollect, Vec> - where - I: IntoIterator + Send, - I::IntoIter: Send + 'fut, - F: Future> + Send + 'fut, - O: Send + 'static, - E: Send + 'static, - { - seq_try_join_all(self.active_work(), iterable) - } - - /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. - /// - /// # Safety - /// Forgetting the future returned from this function will cause use-after-free. This is a tradeoff between - /// performance and safety that allows us to use regular references instead of Arc pointers. - /// - /// Dropping the future is always safe. - #[cfg(feature = "multi-threading")] - fn parallel_join<'a, I, F, O, E>( - &self, - iterable: I, - ) -> Pin, E>> + Send + 'a>> - where - I: IntoIterator + Send, - F: Future> + Send + 'a, - O: Send + 'static, - E: Send + 'static, - { - unsafe { Box::pin(multi_thread::parallel_join(iterable)) } - } - - /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. - #[cfg(not(feature = "multi-threading"))] - fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll - where - I: IntoIterator, - I::Item: futures::future::TryFuture, - { - #[allow(clippy::disallowed_methods)] // Just in this one place. - futures::future::try_join_all(iterable) - } - - /// The amount of active work that is concurrently permitted. - fn active_work(&self) -> NonZeroUsize; -} - -type SeqTryJoinAll<'st, I, F> = - SequentialFutures<'st, StreamIter<::IntoIter>, F>; - -/// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. -/// This awaits all the provided futures in order, -/// aborting early if any future returns `Result::Err`. -pub fn seq_try_join_all<'iter, I, F, O, E>( - active: NonZeroUsize, - source: I, -) -> TryCollect, Vec> -where - I: IntoIterator + Send, - I::IntoIter: Send + 'iter, - F: Future> + Send + 'iter, - O: Send + 'static, - E: Send + 'static, -{ - seq_join(active, iter(source)).try_collect() -} - -impl<'fut, S, F> ExactSizeStream for SequentialFutures<'fut, S, F> -where - S: Stream + Send + ExactSizeStream, - F: IntoFuture, - ::IntoFuture: Send + 'fut, - <::IntoFuture as Future>::Output: Send + 'static, -{ -} - -#[cfg(not(feature = "multi-threading"))] -pub use local::SequentialFutures; -#[cfg(feature = "multi-threading")] -pub use multi_thread::SequentialFutures; - -/// Parallel and sequential join that use at most one thread. Good for unit testing and debugging, -/// to get results in predictable order with fewer things happening at the same time. -#[cfg(not(feature = "multi-threading"))] -mod local { - use std::{collections::VecDeque, marker::PhantomData}; - use futures::stream::Fuse; - - use super::*; - - enum ActiveItem { - Pending(Pin>), - Resolved(F::Output), - } - - impl ActiveItem { - /// Drives this item to resolved state when value is ready to be taken out. Has no effect - /// if the value is ready. - /// - /// ## Panics - /// Panics if this item is completed - fn check_ready(&mut self, cx: &mut Context<'_>) -> bool { - let ActiveItem::Pending(f) = self else { - return true; - }; - if let Poll::Ready(v) = Future::poll(Pin::as_mut(f), cx) { - *self = ActiveItem::Resolved(v); - true - } else { - false - } - } - - /// Takes the resolved value out - /// - /// ## Panics - /// If the value is not ready yet. - #[must_use] - fn take(self) -> F::Output { - let ActiveItem::Resolved(v) = self else { - unreachable!("take should be only called once."); - }; - - v - } - } - - #[pin_project] - pub struct SequentialFutures<'unused, S, F> - where - S: Stream + Send, - F: IntoFuture, - { - #[pin] - source: Fuse, - active: VecDeque>, - _marker: PhantomData &'unused ()>, - } - - impl SequentialFutures<'_, S, F> - where - S: Stream + Send, - F: IntoFuture, - { - pub fn new(active: NonZeroUsize, source: S) -> Self { - Self { - source: source.fuse(), - active: VecDeque::with_capacity(active.get()), - _marker: PhantomData, - } - } - } - - impl Stream for SequentialFutures<'_, S, F> - where - S: Stream + Send, - F: IntoFuture, - { - type Item = F::Output; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - // Draw more values from the input, up to the capacity. - while this.active.len() < this.active.capacity() { - if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { - this.active - .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); - } else { - break; - } - } - - if let Some(item) = this.active.front_mut() { - if item.check_ready(cx) { - let v = this.active.pop_front().map(ActiveItem::take); - Poll::Ready(v) - } else { - for f in this.active.iter_mut().skip(1) { - f.check_ready(cx); - } - Poll::Pending - } - } else if this.source.is_done() { - Poll::Ready(None) - } else { - Poll::Pending - } - } - - fn size_hint(&self) -> (usize, Option) { - let in_progress = self.active.len(); - let (lower, upper) = self.source.size_hint(); - ( - lower.saturating_add(in_progress), - upper.and_then(|u| u.checked_add(in_progress)), - ) - } - } -} - -/// Both joins use executor tasks to drive futures to completion. Much faster than single-threaded -/// version, so this is what we want to use in release/prod mode. -#[cfg(feature = "multi-threading")] -mod multi_thread { - use futures::stream::Fuse; - use tracing::{Instrument, Span}; - - use super::*; - - #[cfg(feature = "shuttle")] - mod shuttle_spawner { - use shuttle_crate::future::{self, JoinError, JoinHandle}; - - use super::*; - - /// Spawner implementation for Shuttle framework to run tests in parallel - pub(super) struct ShuttleSpawner; - - unsafe impl async_scoped::spawner::Spawner for ShuttleSpawner - where - T: Send + 'static, - { - type FutureOutput = Result; - type SpawnHandle = JoinHandle; - - fn spawn + Send + 'static>(&self, f: F) -> Self::SpawnHandle { - future::spawn(f) - } - } - - unsafe impl async_scoped::spawner::Blocker for ShuttleSpawner { - fn block_on>(&self, f: F) -> T { - future::block_on(f) - } - } - } - - #[cfg(feature = "shuttle")] - type Spawner<'fut, T> = async_scoped::Scope<'fut, T, shuttle_spawner::ShuttleSpawner>; - #[cfg(not(feature = "shuttle"))] - type Spawner<'fut, T> = async_scoped::TokioScope<'fut, T>; - - unsafe fn create_spawner<'fut, T: Send + 'static>() -> Spawner<'fut, T> { - #[cfg(feature = "shuttle")] - return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); - #[cfg(not(feature = "shuttle"))] - return async_scoped::TokioScope::create(async_scoped::spawner::use_tokio::Tokio); - } - - #[pin_project] - #[must_use = "Futures do nothing unless polled"] - pub struct SequentialFutures<'fut, S, F> - where - S: Stream + Send + 'fut, - F: IntoFuture, - <::IntoFuture as Future>::Output: Send + 'static, - { - #[pin] - spawner: Spawner<'fut, F::Output>, - #[pin] - source: Fuse, - capacity: usize, - } - - impl SequentialFutures<'_, S, F> - where - S: Stream + Send, - F: IntoFuture, - <::IntoFuture as Future>::Output: Send + 'static, - { - pub unsafe fn new(active: NonZeroUsize, source: S) -> Self { - SequentialFutures { - spawner: unsafe { create_spawner() }, - source: source.fuse(), - capacity: active.get(), - } - } - } - - impl<'fut, S, F> Stream for SequentialFutures<'fut, S, F> - where - S: Stream + Send, - F: IntoFuture, - ::IntoFuture: Send + 'fut, - <::IntoFuture as Future>::Output: Send + 'static, - { - type Item = F::Output; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - // Draw more values from the input, up to the capacity. - while this.spawner.remaining() < *this.capacity { - if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { - // Making futures cancellable is critical to avoid hangs. - // if one of them panics, unwinding causes spawner to drop and, in turn, - // it blocks the thread to await all pending futures completion. If there is - // a dependency between futures, pending one will never complete. - // Cancellable futures will be cancelled when spawner is dropped which is - // the behavior we want. - this.spawner - .spawn_cancellable(f.into_future().instrument(Span::current()), || { - panic!("SequentialFutures: spawned task cancelled") - }); - } else { - break; - } - } - - // Poll spawner if it has work to do. If both source and spawner are empty, we're done. - if this.spawner.remaining() > 0 { - this.spawner.as_mut().poll_next(cx).map(|v| match v { - Some(Ok(v)) => Some(v), - Some(Err(_)) => panic!("SequentialFutures: spawned task aborted"), - None => None, - }) - } else if this.source.is_done() { - Poll::Ready(None) - } else { - Poll::Pending - } - } - - fn size_hint(&self) -> (usize, Option) { - let in_progress = self.spawner.remaining(); - let (lower, upper) = self.source.size_hint(); - ( - lower.saturating_add(in_progress), - upper.and_then(|u| u.checked_add(in_progress)), - ) - } - } - - pub(super) unsafe fn parallel_join<'fut, I, F, O, E>( - iterable: I, - ) -> impl Future, E>> + Send + 'fut - where - I: IntoIterator + Send, - F: Future> + Send + 'fut, - O: Send + 'static, - E: Send + 'static, - { - let mut scope = { - let mut scope = unsafe { create_spawner() }; - for element in iterable { - // it is important to make those cancellable to avoid deadlocks if one of the spawned future panics. - // If there is a dependency between futures, pending one will never complete. - // Cancellable futures will be cancelled when spawner is dropped which is the behavior we want. - scope.spawn_cancellable(element.instrument(Span::current()), || { - panic!("parallel_join: task cancelled") - }); - } - scope - }; - - async move { - let mut result = Vec::with_capacity(scope.len()); - while let Some(item) = scope.next().await { - // join error is nothing we can do about - result.push(item.expect("parallel_join: received JoinError")?) - } - Ok(result) - } - } -} - -#[cfg(all(test, unit_test, not(feature = "multi-threading")))] -mod local_test { - use std::{ - num::NonZeroUsize, - ptr::null, - sync::{Arc, Mutex}, - task::{Context, Poll, Waker}, - }; - - use futures::{ - future::lazy, - stream::{poll_fn, repeat_with}, - StreamExt, - }; - - use super::*; - - fn fake_waker() -> Waker { - use std::task::{RawWaker, RawWakerVTable}; - const fn fake_raw_waker() -> RawWaker { - const TABLE: RawWakerVTable = - RawWakerVTable::new(|_| fake_raw_waker(), |_| {}, |_| {}, |_| {}); - RawWaker::new(null(), &TABLE) - } - unsafe { Waker::from_raw(fake_raw_waker()) } - } - - /// Check the value of a counter, then reset it. - fn assert_count(counter_r: &Arc>, expected: usize) { - let mut counter = counter_r.lock().unwrap(); - assert_eq!(*counter, expected); - *counter = 0; - } - - /// A fully synchronous test. - #[test] - fn synchronous() { - let capacity = NonZeroUsize::new(3).unwrap(); - let v_r: Arc>> = Arc::new(Mutex::new(None)); - let v_w = Arc::clone(&v_r); - // Track when the stream was polled, - let polled_w: Arc> = Arc::new(Mutex::new(0)); - let polled_r = Arc::clone(&polled_w); - // when the stream produced something, and - let produced_w: Arc> = Arc::new(Mutex::new(0)); - let produced_r = Arc::clone(&produced_w); - // when the future was read. - let read_w: Arc> = Arc::new(Mutex::new(0)); - let read_r = Arc::clone(&read_w); - - let stream = poll_fn(|_cx| { - *polled_w.lock().unwrap() += 1; - if let Some(v) = v_r.lock().unwrap().take() { - *produced_w.lock().unwrap() += 1; - let read_w = Arc::clone(&read_w); - Poll::Ready(Some(lazy(move |_| { - *read_w.lock().unwrap() += 1; - v - }))) - } else { - // Note: we can ignore `cx` because we are driving this directly. - Poll::Pending - } - }); - let mut joined = seq_join(capacity, stream); - let waker = fake_waker(); - let mut cx = Context::from_waker(&waker); - - let res = joined.poll_next_unpin(&mut cx); - assert_count(&polled_r, 1); - assert_count(&produced_r, 0); - assert_count(&read_r, 0); - assert!(res.is_pending()); - - *v_w.lock().unwrap() = Some(7); - let res = joined.poll_next_unpin(&mut cx); - assert_count(&polled_r, 2); - assert_count(&produced_r, 1); - assert_count(&read_r, 1); - assert!(matches!(res, Poll::Ready(Some(7)))); - } - - /// A fully synchronous test with a synthetic stream, all the way to the end. - #[test] - fn complete_stream() { - const VALUE: u32 = 20; - const COUNT: usize = 7; - let capacity = NonZeroUsize::new(3).unwrap(); - // Track the number of values produced. - let produced_w: Arc> = Arc::new(Mutex::new(0)); - let produced_r = Arc::clone(&produced_w); - - let stream = repeat_with(|| { - *produced_w.lock().unwrap() += 1; - lazy(|_| VALUE) - }) - .take(COUNT); - let mut joined = seq_join(capacity, stream); - let waker = fake_waker(); - let mut cx = Context::from_waker(&waker); - - // The first poll causes the active buffer to be filled if that is possible. - let res = joined.poll_next_unpin(&mut cx); - assert_count(&produced_r, capacity.get()); - assert!(matches!(res, Poll::Ready(Some(VALUE)))); - - // A few more iterations, where each top up the buffer. - for _ in 0..(COUNT - capacity.get()) { - let res = joined.poll_next_unpin(&mut cx); - assert_count(&produced_r, 1); - assert!(matches!(res, Poll::Ready(Some(VALUE)))); - } - - // Then we drain the buffer. - for _ in 0..(capacity.get() - 1) { - let res = joined.poll_next_unpin(&mut cx); - assert_count(&produced_r, 0); - assert!(matches!(res, Poll::Ready(Some(VALUE)))); - } - - // Then the stream ends. - let res = joined.poll_next_unpin(&mut cx); - assert_count(&produced_r, 0); - assert!(matches!(res, Poll::Ready(None))); - } -} - -#[cfg(all(test, unit_test))] -mod test { - use std::{convert::Infallible, iter::once}; - - use futures::{ - future::{lazy, BoxFuture}, - stream::{iter, poll_immediate}, - Future, StreamExt, - }; - - use super::*; - use crate::test_executor::run; - - async fn immediate(count: u32) { - let capacity = NonZeroUsize::new(3).unwrap(); - let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) - .collect::>() - .await; - assert_eq!((0..count).collect::>(), values); - } - - #[test] - fn within_capacity() { - run(|| async { - immediate(2).await; - immediate(1).await; - }); - } - - #[test] - fn over_capacity() { - run(|| async { - immediate(10).await; - }); - } - - #[test] - fn size() { - run(|| async { - let mut count = 10_usize; - let capacity = NonZeroUsize::new(3).unwrap(); - let mut values = seq_join(capacity, iter((0..count).map(|i| async move { i }))); - assert_eq!((count, Some(count)), values.size_hint()); - - while values.next().await.is_some() { - count -= 1; - assert_eq!((count, Some(count)), values.size_hint()); - } - }); - } - - #[test] - fn out_of_order() { - run(|| async { - let capacity = NonZeroUsize::new(3).unwrap(); - let barrier = tokio::sync::Barrier::new(2); - let unresolved: BoxFuture<'_, u32> = Box::pin(async { - barrier.wait().await; - 0 - }); - let it = once(unresolved) - .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); - let mut seq_futures = seq_join(capacity, iter(it)); - - assert_eq!( - Some(Poll::Pending), - poll_immediate(&mut seq_futures).next().await - ); - barrier.wait().await; - assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); - }); - } - - #[test] - fn join_success() { - fn f(v: T) -> impl Future> { - lazy(move |_| Ok(v)) - } - - run(|| async { - let active = NonZeroUsize::new(10).unwrap(); - let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); - assert_eq!((1..5).collect::>(), res); - }); - } - - /// This test has to use multi-threaded runtime because early return causes `TryCollect` to be - /// dropped and the remaining futures to be cancelled which can only happen if there is more - /// than one thread available. - /// - /// This behavior is only applicable when `seq_try_join_all` uses more than one thread, for - /// maintenance reasons, we use it even when parallelism is turned off. - #[test] - fn try_join_early_abort() { - const ERROR: &str = "error message"; - fn f(i: u32) -> impl Future> { - lazy(move |_| match i { - 1 => Ok(1), - 2 => Err(ERROR), - _ => panic!("should have aborted earlier"), - }) - } - - run(|| async { - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); - }); - } - - #[test] - fn does_not_block_on_error() { - const ERROR: &str = "returning early is safe"; - use std::pin::Pin; - - fn f(i: u32) -> Pin> + Send>> { - match i { - 1 => Box::pin(lazy(move |_| Ok(1))), - 2 => Box::pin(lazy(move |_| Err(ERROR))), - _ => Box::pin(futures::future::pending()), - } - } - - run(|| async { - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); - }); - } - - /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause - /// use-after-free safety error. It spawns a few tasks that constantly try to access the `borrow_from_me` weak - /// reference while the main thread drops the owning reference. By proving that futures are able to see the weak - /// pointer unset, this test shows that same can happen for regular references and cause use-after-free. - #[test] - #[cfg(feature = "multi-threading")] - fn parallel_join_forget_is_not_safe() { - use futures::future::poll_immediate; - - use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; - - run(|| async { - const N: usize = 24; - let borrowed_vec = Box::new([1, 2, 3]); - let borrow_from_me = Arc::new(vec![1, 2, 3]); - let start = Arc::new(tokio::sync::Barrier::new(N + 1)); - // counts how many tasks have accessed `borrow_from_me` after it was destroyed. - // this test expects all tasks to access `borrow_from_me` at least once. - let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1)); - - let futures = (0..N) - .map(|_| { - let borrowed = Arc::downgrade(&borrow_from_me); - let regular_ref = &borrowed_vec; - let start = start.clone(); - let bad_access = bad_accesses.clone(); - async move { - start.wait().await; - for _ in 0..100 { - if borrowed.upgrade().is_none() { - bad_access.wait().await; - // switch to `true` if you want to see the real corruption. - #[allow(unreachable_code)] - if false { - // this is a place where we can see the use-after-free. - // we avoid executing this block to appease sanitizers, but compiler happily - // allows us to follow this reference. - println!("{:?}", regular_ref); - } - break; - } - tokio::task::yield_now().await; - } - Ok::<_, ()>(()) - } - }) - .collect::>(); - - let mut f = Box::pin(unsafe { parallel_join(futures) }); - poll_immediate(&mut f).await; - start.wait().await; - - // the type of `f` above captures the lifetime for borrowed_vec. Leaking `f` allows `borrowed_vec` to be - // dropped, but that drop prohibits any subsequent manipulations with `f` pointer, irrespective of whether - // `f` is `&mut _` or `*mut _` (value already borrowed error). - // I am not sure I fully understand what is going on here (why borrowck allows me to leak the value, but - // then I can't drop it even if it is a raw pointer), but removing the lifetime from `f` type allows - // the test to pass. - // - // This is only required to do the proper cleanup and avoid memory leaks. Replacing this line with - // `mem::forget(f)` will lead to the same test outcome, but Miri will complain about memory leaks. - let f: _ = unsafe { - std::mem::transmute::<_, Pin, ()>> + Send>>>( - Box::pin(f) as Pin, ()>>>>, - ) - }; - - // Async executor will still be polling futures and they will try to follow this pointer. - drop(borrow_from_me); - drop(borrowed_vec); - - // this test should terminate because all tasks should access `borrow_from_me` at least once. - bad_accesses.wait().await; - - drop(f); - }); - } -} diff --git a/ipa-core/src/seq_join/local.rs b/ipa-core/src/seq_join/local.rs new file mode 100644 index 000000000..33fe3d757 --- /dev/null +++ b/ipa-core/src/seq_join/local.rs @@ -0,0 +1,268 @@ +use std::{ + collections::VecDeque, + future::IntoFuture, + marker::PhantomData, + num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::stream::Fuse; + +use super::*; + +enum ActiveItem { + Pending(Pin>), + Resolved(F::Output), +} + +impl ActiveItem { + /// Drives this item to resolved state when value is ready to be taken out. Has no effect + /// if the value is ready. + /// + /// ## Panics + /// Panics if this item is completed + fn check_ready(&mut self, cx: &mut Context<'_>) -> bool { + let ActiveItem::Pending(f) = self else { + return true; + }; + if let Poll::Ready(v) = Future::poll(Pin::as_mut(f), cx) { + *self = ActiveItem::Resolved(v); + true + } else { + false + } + } + + /// Takes the resolved value out + /// + /// ## Panics + /// If the value is not ready yet. + #[must_use] + fn take(self) -> F::Output { + let ActiveItem::Resolved(v) = self else { + unreachable!("take should be only called once."); + }; + + v + } +} + +#[pin_project] +pub struct SequentialFutures<'unused, S, F> +where + S: Stream + Send, + F: IntoFuture, +{ + #[pin] + source: Fuse, + active: VecDeque>, + _marker: PhantomData &'unused ()>, +} + +impl SequentialFutures<'_, S, F> +where + S: Stream + Send, + F: IntoFuture, +{ + pub fn new(active: NonZeroUsize, source: S) -> Self { + Self { + source: source.fuse(), + active: VecDeque::with_capacity(active.get()), + _marker: PhantomData, + } + } +} + +impl Stream for SequentialFutures<'_, S, F> +where + S: Stream + Send, + F: IntoFuture, +{ + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. + while this.active.len() < this.active.capacity() { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + this.active + .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); + } else { + break; + } + } + + if let Some(item) = this.active.front_mut() { + if item.check_ready(cx) { + let v = this.active.pop_front().map(ActiveItem::take); + Poll::Ready(v) + } else { + for f in this.active.iter_mut().skip(1) { + f.check_ready(cx); + } + Poll::Pending + } + } else if this.source.is_done() { + Poll::Ready(None) + } else { + Poll::Pending + } + } + + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.active.len(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } +} + +#[cfg(all(test, unit_test))] +mod local_test { + use std::{ + num::NonZeroUsize, + ptr::null, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, + }; + + use futures::{ + future::lazy, + stream::{poll_fn, repeat_with}, + StreamExt, + }; + + use super::*; + use crate::test_executor::run; + + fn fake_waker() -> Waker { + use std::task::{RawWaker, RawWakerVTable}; + const fn fake_raw_waker() -> RawWaker { + const TABLE: RawWakerVTable = + RawWakerVTable::new(|_| fake_raw_waker(), |_| {}, |_| {}, |_| {}); + RawWaker::new(null(), &TABLE) + } + unsafe { Waker::from_raw(fake_raw_waker()) } + } + + /// Check the value of a counter, then reset it. + fn assert_count(counter_r: &Arc>, expected: usize) { + let mut counter = counter_r.lock().unwrap(); + assert_eq!(*counter, expected); + *counter = 0; + } + + /// A fully synchronous test. + #[test] + fn synchronous() { + let capacity = NonZeroUsize::new(3).unwrap(); + let v_r: Arc>> = Arc::new(Mutex::new(None)); + let v_w = Arc::clone(&v_r); + // Track when the stream was polled, + let polled_w: Arc> = Arc::new(Mutex::new(0)); + let polled_r = Arc::clone(&polled_w); + // when the stream produced something, and + let produced_w: Arc> = Arc::new(Mutex::new(0)); + let produced_r = Arc::clone(&produced_w); + // when the future was read. + let read_w: Arc> = Arc::new(Mutex::new(0)); + let read_r = Arc::clone(&read_w); + + let stream = poll_fn(|_cx| { + *polled_w.lock().unwrap() += 1; + if let Some(v) = v_r.lock().unwrap().take() { + *produced_w.lock().unwrap() += 1; + let read_w = Arc::clone(&read_w); + Poll::Ready(Some(lazy(move |_| { + *read_w.lock().unwrap() += 1; + v + }))) + } else { + // Note: we can ignore `cx` because we are driving this directly. + Poll::Pending + } + }); + let mut joined = seq_join(capacity, stream); + let waker = fake_waker(); + let mut cx = Context::from_waker(&waker); + + let res = joined.poll_next_unpin(&mut cx); + assert_count(&polled_r, 1); + assert_count(&produced_r, 0); + assert_count(&read_r, 0); + assert!(res.is_pending()); + + *v_w.lock().unwrap() = Some(7); + let res = joined.poll_next_unpin(&mut cx); + assert_count(&polled_r, 2); + assert_count(&produced_r, 1); + assert_count(&read_r, 1); + assert!(matches!(res, Poll::Ready(Some(7)))); + } + + /// A fully synchronous test with a synthetic stream, all the way to the end. + #[test] + fn complete_stream() { + const VALUE: u32 = 20; + const COUNT: usize = 7; + let capacity = NonZeroUsize::new(3).unwrap(); + // Track the number of values produced. + let produced_w: Arc> = Arc::new(Mutex::new(0)); + let produced_r = Arc::clone(&produced_w); + + let stream = repeat_with(|| { + *produced_w.lock().unwrap() += 1; + lazy(|_| VALUE) + }) + .take(COUNT); + let mut joined = seq_join(capacity, stream); + let waker = fake_waker(); + let mut cx = Context::from_waker(&waker); + + // The first poll causes the active buffer to be filled if that is possible. + let res = joined.poll_next_unpin(&mut cx); + assert_count(&produced_r, capacity.get()); + assert!(matches!(res, Poll::Ready(Some(VALUE)))); + + // A few more iterations, where each top up the buffer. + for _ in 0..(COUNT - capacity.get()) { + let res = joined.poll_next_unpin(&mut cx); + assert_count(&produced_r, 1); + assert!(matches!(res, Poll::Ready(Some(VALUE)))); + } + + // Then we drain the buffer. + for _ in 0..(capacity.get() - 1) { + let res = joined.poll_next_unpin(&mut cx); + assert_count(&produced_r, 0); + assert!(matches!(res, Poll::Ready(Some(VALUE)))); + } + + // Then the stream ends. + let res = joined.poll_next_unpin(&mut cx); + assert_count(&produced_r, 0); + assert!(matches!(res, Poll::Ready(None))); + } + + #[test] + fn try_join_early_abort() { + const ERROR: &str = "error message"; + fn f(i: u32) -> impl Future> { + lazy(move |_| match i { + 1 => Ok(1), + 2 => Err(ERROR), + _ => panic!("should have aborted earlier"), + }) + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + }); + } +} diff --git a/ipa-core/src/seq_join/mod.rs b/ipa-core/src/seq_join/mod.rs new file mode 100644 index 000000000..dfe6b1073 --- /dev/null +++ b/ipa-core/src/seq_join/mod.rs @@ -0,0 +1,274 @@ +use std::{future::IntoFuture, num::NonZeroUsize}; + +use futures::{ + stream::{iter, Iter as StreamIter, TryCollect}, + Future, Stream, StreamExt, TryStreamExt, +}; +use pin_project::pin_project; + +use crate::exact::ExactSizeStream; + +#[cfg(not(feature = "multi-threading"))] +mod local; +#[cfg(feature = "multi-threading")] +mod multi_thread; + +/// This helper function might be necessary to convince the compiler that +/// the return value from [`seq_try_join_all`] implements `Send`. +/// Use this if you get higher-ranked lifetime errors that mention `std::marker::Send`. +/// +/// +pub fn assert_send<'a, O>( + fut: impl Future + Send + 'a, +) -> impl Future + Send + 'a { + fut +} + +/// Sequentially join futures from a stream. +/// +/// This function polls futures in strict sequence. +/// If any future blocks, up to `active - 1` futures after it will be polled so +/// that they make progress. +/// +/// # Deadlocks +/// +/// This will fail to resolve if the progress of any future depends on a future more +/// than `active` items behind it in the input sequence. +/// +/// # Safety +/// If multi-threading is enabled, forgetting the resulting future will cause use-after-free error. Do not leak it or +/// prevent the future destructor from running. +/// +/// [`try_join_all`]: futures::future::try_join_all +/// [`Stream`]: futures::stream::Stream +/// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered +pub fn seq_join<'st, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'st, S, F> +where + S: Stream + Send + 'st, + F: Future + Send, + O: Send + 'static, +{ + #[cfg(feature = "multi-threading")] + unsafe { + SequentialFutures::new(active, source) + } + #[cfg(not(feature = "multi-threading"))] + SequentialFutures::new(active, source) +} + +/// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter +/// from the provided context so that the value can be made consistent. +pub trait SeqJoin { + /// Perform a sequential join of the futures from the provided iterable. + /// This uses [`seq_join`], with the current state of the associated object + /// being used to determine the number of active items to track (see [`active_work`]). + /// + /// A rough rule of thumb for how to decide between this and [`parallel_join`] is + /// that this should be used whenever you are iterating over different records. + /// [`parallel_join`] is better suited to smaller batches, such as iterating over + /// the bits of a value for a single record. + /// + /// Note that the join functions from the [`futures`] crate, such as [`join3`], + /// are also parallel and can be used where you have a small, fixed number of tasks. + /// + /// Be especially careful if you use the random bits generator with this. + /// The random bits generator can produce values out of sequence. + /// You might need to use [`parallel_join`] for that. + /// + /// [`active_work`]: Self::active_work + /// [`parallel_join`]: Self::parallel_join + /// [`join3`]: futures::future::join3 + fn try_join<'fut, I, F, O, E>( + &self, + iterable: I, + ) -> TryCollect, Vec> + where + I: IntoIterator + Send, + I::IntoIter: Send + 'fut, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, + { + seq_try_join_all(self.active_work(), iterable) + } + + /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + /// + /// # Safety + /// Forgetting the future returned from this function will cause use-after-free. This is a tradeoff between + /// performance and safety that allows us to use regular references instead of Arc pointers. + /// + /// Dropping the future is always safe. + #[cfg(feature = "multi-threading")] + fn parallel_join<'a, I, F, O, E>( + &self, + iterable: I, + ) -> std::pin::Pin, E>> + Send + 'a>> + where + I: IntoIterator + Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static, + { + unsafe { Box::pin(multi_thread::parallel_join(iterable)) } + } + + /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + #[cfg(not(feature = "multi-threading"))] + fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll + where + I: IntoIterator, + I::Item: futures::future::TryFuture, + { + #[allow(clippy::disallowed_methods)] // Just in this one place. + futures::future::try_join_all(iterable) + } + + /// The amount of active work that is concurrently permitted. + fn active_work(&self) -> NonZeroUsize; +} + +type SeqTryJoinAll<'st, I, F> = + SequentialFutures<'st, StreamIter<::IntoIter>, F>; + +/// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. +/// This awaits all the provided futures in order, +/// aborting early if any future returns `Result::Err`. +pub fn seq_try_join_all<'iter, I, F, O, E>( + active: NonZeroUsize, + source: I, +) -> TryCollect, Vec> +where + I: IntoIterator + Send, + I::IntoIter: Send + 'iter, + F: Future> + Send + 'iter, + O: Send + 'static, + E: Send + 'static, +{ + seq_join(active, iter(source)).try_collect() +} + +impl<'fut, S, F> ExactSizeStream for SequentialFutures<'fut, S, F> +where + S: Stream + Send + ExactSizeStream, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, +{ +} + +#[cfg(not(feature = "multi-threading"))] +pub use local::SequentialFutures; +#[cfg(feature = "multi-threading")] +pub use multi_thread::SequentialFutures; + +#[cfg(all(test, any(unit_test, feature = "shuttle")))] +mod test { + use std::{convert::Infallible, iter::once, task::Poll}; + + use futures::{ + future::{lazy, BoxFuture}, + stream::{iter, poll_immediate}, + Future, StreamExt, + }; + + use super::*; + use crate::test_executor::run; + + async fn immediate(count: u32) { + let capacity = NonZeroUsize::new(3).unwrap(); + let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) + .collect::>() + .await; + assert_eq!((0..count).collect::>(), values); + } + + #[test] + fn within_capacity() { + run(|| async { + immediate(2).await; + immediate(1).await; + }); + } + + #[test] + fn over_capacity() { + run(|| async { + immediate(10).await; + }); + } + + #[test] + fn size() { + run(|| async { + let mut count = 10_usize; + let capacity = NonZeroUsize::new(3).unwrap(); + let mut values = seq_join(capacity, iter((0..count).map(|i| async move { i }))); + assert_eq!((count, Some(count)), values.size_hint()); + + while values.next().await.is_some() { + count -= 1; + assert_eq!((count, Some(count)), values.size_hint()); + } + }); + } + + #[test] + fn out_of_order() { + run(|| async { + let capacity = NonZeroUsize::new(3).unwrap(); + let barrier = tokio::sync::Barrier::new(2); + let unresolved: BoxFuture<'_, u32> = Box::pin(async { + barrier.wait().await; + 0 + }); + let it = once(unresolved) + .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); + let mut seq_futures = seq_join(capacity, iter(it)); + + assert_eq!( + Some(Poll::Pending), + poll_immediate(&mut seq_futures).next().await + ); + barrier.wait().await; + assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); + }); + } + + #[test] + fn join_success() { + fn f(v: T) -> impl Future> { + lazy(move |_| Ok(v)) + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); + assert_eq!((1..5).collect::>(), res); + }); + } + + #[test] + #[cfg_attr( + all(feature = "shuttle", feature = "multi-threading"), + should_panic(expected = "cancelled") + )] + fn does_not_block_on_error() { + const ERROR: &str = "returning early is safe"; + use std::pin::Pin; + + fn f(i: u32) -> Pin> + Send>> { + match i { + 1 => Box::pin(lazy(move |_| Ok(1))), + 2 => Box::pin(lazy(move |_| Err(ERROR))), + _ => Box::pin(futures::future::pending()), + } + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + }); + } +} diff --git a/ipa-core/src/seq_join/multi_thread.rs b/ipa-core/src/seq_join/multi_thread.rs new file mode 100644 index 000000000..79ee89d6e --- /dev/null +++ b/ipa-core/src/seq_join/multi_thread.rs @@ -0,0 +1,252 @@ +use std::{ + future::IntoFuture, + num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::stream::Fuse; +use tracing::{Instrument, Span}; + +use super::*; + +#[cfg(feature = "shuttle")] +mod shuttle_spawner { + use shuttle_crate::future::{self, JoinError, JoinHandle}; + + use super::*; + + /// Spawner implementation for Shuttle framework to run tests in parallel + pub(super) struct ShuttleSpawner; + + unsafe impl async_scoped::spawner::Spawner for ShuttleSpawner + where + T: Send + 'static, + { + type FutureOutput = Result; + type SpawnHandle = JoinHandle; + + fn spawn + Send + 'static>(&self, f: F) -> Self::SpawnHandle { + future::spawn(f) + } + } + + unsafe impl async_scoped::spawner::Blocker for ShuttleSpawner { + fn block_on>(&self, f: F) -> T { + future::block_on(f) + } + } +} + +#[cfg(feature = "shuttle")] +type Spawner<'fut, T> = async_scoped::Scope<'fut, T, shuttle_spawner::ShuttleSpawner>; +#[cfg(not(feature = "shuttle"))] +type Spawner<'fut, T> = async_scoped::TokioScope<'fut, T>; + +unsafe fn create_spawner<'fut, T: Send + 'static>() -> Spawner<'fut, T> { + #[cfg(feature = "shuttle")] + return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); + #[cfg(not(feature = "shuttle"))] + return async_scoped::TokioScope::create(async_scoped::spawner::use_tokio::Tokio); +} + +#[pin_project] +#[must_use = "Futures do nothing unless polled"] +pub struct SequentialFutures<'fut, S, F> +where + S: Stream + Send + 'fut, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, +{ + #[pin] + spawner: Spawner<'fut, F::Output>, + #[pin] + source: Fuse, + capacity: usize, +} + +impl SequentialFutures<'_, S, F> +where + S: Stream + Send, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, +{ + pub unsafe fn new(active: NonZeroUsize, source: S) -> Self { + SequentialFutures { + spawner: unsafe { create_spawner() }, + source: source.fuse(), + capacity: active.get(), + } + } +} + +impl<'fut, S, F> Stream for SequentialFutures<'fut, S, F> +where + S: Stream + Send, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, +{ + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. + while this.spawner.remaining() < *this.capacity { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + // Making futures cancellable is critical to avoid hangs. + // if one of them panics, unwinding causes spawner to drop and, in turn, + // it blocks the thread to await all pending futures completion. If there is + // a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is + // the behavior we want. + let task_index = this.spawner.len(); + this.spawner + .spawn_cancellable(f.into_future().instrument(Span::current()), move || { + panic!("SequentialFutures: spawned task {task_index} cancelled") + }); + } else { + break; + } + } + + // Poll spawner if it has work to do. If both source and spawner are empty, we're done. + if this.spawner.remaining() > 0 { + this.spawner.as_mut().poll_next(cx).map(|v| match v { + Some(Ok(v)) => Some(v), + Some(Err(_)) => panic!("SequentialFutures: spawned task aborted"), + None => None, + }) + } else if this.source.is_done() { + Poll::Ready(None) + } else { + Poll::Pending + } + } + + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.spawner.remaining(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } +} + +pub(super) unsafe fn parallel_join<'fut, I, F, O, E>( + iterable: I, +) -> impl Future, E>> + Send + 'fut +where + I: IntoIterator + Send, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, +{ + let mut scope = { + let mut scope = unsafe { create_spawner() }; + for element in iterable { + // it is important to make those cancellable to avoid deadlocks if one of the spawned future panics. + // If there is a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is the behavior we want. + scope.spawn_cancellable(element.instrument(Span::current()), || { + panic!("parallel_join: task cancelled") + }); + } + scope + }; + + async move { + let mut result = Vec::with_capacity(scope.len()); + while let Some(item) = scope.next().await { + // join error is nothing we can do about + result.push(item.expect("parallel_join: received JoinError")?) + } + Ok(result) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{future::Future, pin::Pin}; + + use crate::test_executor::run; + + /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause + /// use-after-free safety error. It spawns a few tasks that constantly try to access the `borrow_from_me` weak + /// reference while the main thread drops the owning reference. By proving that futures are able to see the weak + /// pointer unset, this test shows that same can happen for regular references and cause use-after-free. + #[test] + fn parallel_join_forget_is_not_safe() { + use futures::future::poll_immediate; + + use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; + + run(|| async { + const N: usize = 24; + let borrowed_vec = Box::new([1, 2, 3]); + let borrow_from_me = Arc::new(vec![1, 2, 3]); + let start = Arc::new(tokio::sync::Barrier::new(N + 1)); + // counts how many tasks have accessed `borrow_from_me` after it was destroyed. + // this test expects all tasks to access `borrow_from_me` at least once. + let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1)); + + let futures = (0..N) + .map(|_| { + let borrowed = Arc::downgrade(&borrow_from_me); + let regular_ref = &borrowed_vec; + let start = start.clone(); + let bad_access = bad_accesses.clone(); + async move { + start.wait().await; + for _ in 0..100 { + if borrowed.upgrade().is_none() { + bad_access.wait().await; + // switch to `true` if you want to see the real corruption. + #[allow(unreachable_code)] + if false { + // this is a place where we can see the use-after-free. + // we avoid executing this block to appease sanitizers, but compiler happily + // allows us to follow this reference. + println!("{:?}", regular_ref); + } + break; + } + tokio::task::yield_now().await; + } + Ok::<_, ()>(()) + } + }) + .collect::>(); + + let mut f = Box::pin(unsafe { parallel_join(futures) }); + poll_immediate(&mut f).await; + start.wait().await; + + // the type of `f` above captures the lifetime for borrowed_vec. Leaking `f` allows `borrowed_vec` to be + // dropped, but that drop prohibits any subsequent manipulations with `f` pointer, irrespective of whether + // `f` is `&mut _` or `*mut _` (value already borrowed error). + // I am not sure I fully understand what is going on here (why borrowck allows me to leak the value, but + // then I can't drop it even if it is a raw pointer), but removing the lifetime from `f` type allows + // the test to pass. + // + // This is only required to do the proper cleanup and avoid memory leaks. Replacing this line with + // `mem::forget(f)` will lead to the same test outcome, but Miri will complain about memory leaks. + let f: _ = unsafe { + std::mem::transmute::<_, Pin, ()>> + Send>>>( + Box::pin(f) as Pin, ()>>>>, + ) + }; + + // Async executor will still be polling futures and they will try to follow this pointer. + drop(borrow_from_me); + drop(borrowed_vec); + + // this test should terminate because all tasks should access `borrow_from_me` at least once. + bad_accesses.wait().await; + + drop(f); + }); + } +} From 28c945e519539d138ad87c7f4deff91d0a96fd5e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 25 Jan 2024 09:59:40 -0800 Subject: [PATCH 27/27] Import async-scoped from crates.io --- ipa-core/Cargo.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 9cc7c1dbe..a28752559 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -75,8 +75,7 @@ ipa-macros = { version = "*", path = "../ipa-macros" } aes = "0.8.3" async-trait = "0.1.68" -# TODO: migrate to crates.io once 0.9 is released: https://github.com/rmanoka/async-scoped/issues/27 -async-scoped = { git = "https://github.com/rmanoka/async-scoped.git", features = ["use-tokio"], optional = true } +async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = [ "rustls",