diff --git a/.clippy.toml b/.clippy.toml index 238592639..5c572f532 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -2,4 +2,8 @@ disallowed-methods = [ { path = "futures::future::join_all", reason = "We don't have a replacement for this method yet. Consider extending `SeqJoin` trait." }, { path = "futures::future::try_join_all", reason = "Use Context.try_join instead." }, + { path = "std::boxed::Box::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, + { path = "std::mem::forget", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, + { path = "std::mem::ManuallyDrop::new", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, + { path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, ] diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index aba7c292a..6cc659bd0 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -63,9 +63,12 @@ jobs: if: ${{ success() || failure() }} run: cargo build --tests - - name: Run Tests + - name: Run tests run: cargo test + - name: Run tests with multithreading feature enabled + run: cargo test --features "multi-threading" + - name: Run Web Tests run: cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" @@ -96,10 +99,10 @@ jobs: run: cargo build --release - name: Build concurrency tests - run: cargo build --release --features shuttle + run: cargo build --release --features "shuttle multi-threading" - name: Run concurrency tests - run: cargo test --release --features shuttle + run: cargo test --release --features "shuttle multi-threading" extra: name: Additional Builds and Concurrency Tests @@ -148,6 +151,7 @@ jobs: fail-fast: false matrix: sanitizer: [address, leak] + features: ['', 'multi-threading'] env: TARGET: x86_64-unknown-linux-gnu steps: @@ -156,7 +160,21 @@ jobs: - name: Add Rust sources run: rustup component add rust-src - name: Run tests with sanitizer - run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" + run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate ${{ matrix.features }}" + + miri: + runs-on: ubuntu-latest + env: + TARGET: x86_64-unknown-linux-gnu + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@nightly + - name: Add Miri + run: rustup component add miri + - name: Setup Miri + run: cargo miri setup + - name: Run seq_join tests + run: cargo miri test --target $TARGET --lib seq_join --features "multi-threading" coverage: name: Measure coverage diff --git a/.pre-commit.stashsIsbN1 b/.pre-commit.stashsIsbN1 new file mode 100644 index 000000000..e69de29bb diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 08d7314c6..a28752559 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -61,6 +61,8 @@ step-trace = ["descriptive-gate"] # of unit tests use it. Compact uses memory-efficient gates and is suitable for production. descriptive-gate = [] compact-gate = ["ipa-macros/compact-gate"] +# Enable using more than one thread for protocol execution. Most of the parallelism occurs at parallel/seq_join operations +multi-threading = ["async-scoped"] # Standalone aggregation protocol. We use IPA infra for communication # but it has nothing to do with IPA. @@ -73,6 +75,7 @@ ipa-macros = { version = "*", path = "../ipa-macros" } aes = "0.8.3" async-trait = "0.1.68" +async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = [ "rustls", @@ -132,7 +135,7 @@ sha2 = "0.10" shuttle-crate = { package = "shuttle", version = "0.6.1", optional = true } thiserror = "1.0" time = { version = "0.3", optional = true } -tokio = { version = "1.28", features = ["fs", "rt", "rt-multi-thread", "macros"] } +tokio = { version = "1.35", features = ["fs", "rt", "rt-multi-thread", "macros"] } # TODO: axum-server holds onto 0.24 and we can't upgrade until they do. Or we move away from axum-server tokio-rustls = { version = "0.24", optional = true } tokio-stream = "0.1.14" diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 0e4f63828..b5fb7924a 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -112,16 +112,19 @@ pub(crate) mod test_executor { run(f); } - pub fn run(f: F) + pub fn run(f: F) -> T where F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future, + Fut: Future, { tokio::runtime::Builder::new_multi_thread() - .enable_all() + // enable_all() is common to use to build Tokio runtime, but it enables both IO and time drivers. + // IO driver is not compatible with Miri (https://github.com/rust-lang/miri/issues/2057) which we use to + // sanitize our tests, so this runtime only enables time driver. + .enable_time() .build() .unwrap() - .block_on(f()); + .block_on(f()) } } diff --git a/ipa-core/src/protocol/basics/reshare.rs b/ipa-core/src/protocol/basics/reshare.rs index 2e9a868e6..70b65c1c3 100644 --- a/ipa-core/src/protocol/basics/reshare.rs +++ b/ipa-core/src/protocol/basics/reshare.rs @@ -42,7 +42,7 @@ use crate::{ /// `to_helper` = (`rand_left`, `rand_right`) = (r0, r1) /// `to_helper.right` = (`rand_right`, part1 + part2) = (r0, part1 + part2) #[async_trait] -pub trait Reshare: Sized { +pub trait Reshare: Sized + 'static { async fn reshare<'fut>( &self, ctx: C, diff --git a/ipa-core/src/protocol/ipa/mod.rs b/ipa-core/src/protocol/ipa/mod.rs index b2509b80f..e7030e2c3 100644 --- a/ipa-core/src/protocol/ipa/mod.rs +++ b/ipa-core/src/protocol/ipa/mod.rs @@ -466,7 +466,10 @@ where .collect::>() } -#[cfg(all(test, any(unit_test, feature = "shuttle")))] +#[cfg(all( + test, + any(unit_test, all(feature = "shuttle", not(feature = "multi-threading"))) +))] pub mod tests { use std::num::NonZeroU32; diff --git a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs index ce2a6a369..ac6aabf33 100644 --- a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs +++ b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs @@ -299,17 +299,17 @@ where /// # Panics /// If the total record count on the context is unspecified. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_bits( +pub fn convert_bits<'a, F, V, C, S, VS>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, Error>> +) -> impl Stream, Error>> + 'a where F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { @@ -320,35 +320,37 @@ where /// Note that unconverted fields are not upgraded, so they might need to be upgraded either before or /// after invoking this function. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_selected_bits( +pub fn convert_selected_bits<'a, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'a where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { convert_some_bits(ctx, binary_shares, RecordId::FIRST, bit_range) } -pub(crate) fn convert_some_bits( +pub(crate) fn convert_some_bits<'a, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, first_record: RecordId, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'a where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { diff --git a/ipa-core/src/protocol/sort/generate_permutation_opt.rs b/ipa-core/src/protocol/sort/generate_permutation_opt.rs index b82e05a80..22d2eed1b 100644 --- a/ipa-core/src/protocol/sort/generate_permutation_opt.rs +++ b/ipa-core/src/protocol/sort/generate_permutation_opt.rs @@ -298,7 +298,9 @@ mod tests { } /// Passing 32 records for Fp31 doesn't work. - #[tokio::test] + /// + /// Requires one extra thread to cancel futures running in parallel with the one that panics. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[should_panic = "prime field ipa_core::ff::prime_field::fp31::Fp31 is too small to sort 32 records"] async fn fp31_overflow() { const COUNT: usize = 32; diff --git a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs index 891f07c0b..59b2fc0ca 100644 --- a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs @@ -77,7 +77,7 @@ impl LinearSecretSharing for AdditiveShare< /// when the protocol is done. This should not be used directly. #[async_trait] pub trait Downgrade: Send { - type Target: Send; + type Target: Send + 'static; async fn downgrade(self) -> UnauthorizedDowngradeWrapper; } diff --git a/ipa-core/src/secret_sharing/scheme.rs b/ipa-core/src/secret_sharing/scheme.rs index c4f5256bc..def4b6664 100644 --- a/ipa-core/src/secret_sharing/scheme.rs +++ b/ipa-core/src/secret_sharing/scheme.rs @@ -7,7 +7,7 @@ use super::SharedValue; use crate::ff::{AddSub, AddSubAssign, Field, GaloisField}; /// Secret sharing scheme i.e. Replicated secret sharing -pub trait SecretSharing: Clone + Debug + Sized + Send + Sync { +pub trait SecretSharing: Clone + Debug + Sized + Send + Sync + 'static { const ZERO: Self; } @@ -21,6 +21,7 @@ pub trait Linear: + Mul + for<'r> Mul<&'r V, Output = Self> + Neg + + 'static { } diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join/local.rs similarity index 53% rename from ipa-core/src/seq_join.rs rename to ipa-core/src/seq_join/local.rs index e3cb8c5b3..33fe3d757 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join/local.rs @@ -1,116 +1,15 @@ use std::{ collections::VecDeque, future::IntoFuture, + marker::PhantomData, num::NonZeroUsize, pin::Pin, task::{Context, Poll}, }; -use futures::{ - stream::{iter, Iter as StreamIter, TryCollect}, - Future, Stream, StreamExt, TryStreamExt, -}; -use pin_project::pin_project; - -use crate::exact::ExactSizeStream; - -/// This helper function might be necessary to convince the compiler that -/// the return value from [`seq_try_join_all`] implements `Send`. -/// Use this if you get higher-ranked lifetime errors that mention `std::marker::Send`. -/// -/// -pub fn assert_send<'a, O>( - fut: impl Future + Send + 'a, -) -> impl Future + Send + 'a { - fut -} - -/// Sequentially join futures from a stream. -/// -/// This function polls futures in strict sequence. -/// If any future blocks, up to `active - 1` futures after it will be polled so -/// that they make progress. -/// -/// # Deadlocks -/// -/// This will fail to resolve if the progress of any future depends on a future more -/// than `active` items behind it in the input sequence. -/// -/// [`try_join_all`]: futures::future::try_join_all -/// [`Stream`]: futures::stream::Stream -/// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered -pub fn seq_join(active: NonZeroUsize, source: S) -> SequentialFutures -where - S: Stream + Send, - F: Future, -{ - SequentialFutures { - source: source.fuse(), - active: VecDeque::with_capacity(active.get()), - } -} - -/// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter -/// from the provided context so that the value can be made consistent. -pub trait SeqJoin { - /// Perform a sequential join of the futures from the provided iterable. - /// This uses [`seq_join`], with the current state of the associated object - /// being used to determine the number of active items to track (see [`active_work`]). - /// - /// A rough rule of thumb for how to decide between this and [`parallel_join`] is - /// that this should be used whenever you are iterating over different records. - /// [`parallel_join`] is better suited to smaller batches, such as iterating over - /// the bits of a value for a single record. - /// - /// Note that the join functions from the [`futures`] crate, such as [`join3`], - /// are also parallel and can be used where you have a small, fixed number of tasks. - /// - /// Be especially careful if you use the random bits generator with this. - /// The random bits generator can produce values out of sequence. - /// You might need to use [`parallel_join`] for that. - /// - /// [`active_work`]: Self::active_work - /// [`parallel_join`]: Self::parallel_join - /// [`join3`]: futures::future::join3 - fn try_join(&self, iterable: I) -> TryCollect, Vec> - where - I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, - { - seq_try_join_all(self.active_work(), iterable) - } - - /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. - fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll - where - I: IntoIterator, - I::Item: futures::future::TryFuture, - { - #[allow(clippy::disallowed_methods)] // Just in this one place. - futures::future::try_join_all(iterable) - } - - /// The amount of active work that is concurrently permitted. - fn active_work(&self) -> NonZeroUsize; -} +use futures::stream::Fuse; -type SeqTryJoinAll = SequentialFutures::IntoIter>, F>; - -/// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. -/// This awaits all the provided futures in order, -/// aborting early if any future returns `Result::Err`. -pub fn seq_try_join_all( - active: NonZeroUsize, - source: I, -) -> TryCollect, Vec> -where - I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, -{ - seq_join(active, iter(source)).try_collect() -} +use super::*; enum ActiveItem { Pending(Pin>), @@ -142,7 +41,7 @@ impl ActiveItem { #[must_use] fn take(self) -> F::Output { let ActiveItem::Resolved(v) = self else { - panic!("No value to take out"); + unreachable!("take should be only called once."); }; v @@ -150,17 +49,32 @@ impl ActiveItem { } #[pin_project] -pub struct SequentialFutures +pub struct SequentialFutures<'unused, S, F> where S: Stream + Send, F: IntoFuture, { #[pin] - source: futures::stream::Fuse, + source: Fuse, active: VecDeque>, + _marker: PhantomData &'unused ()>, } -impl Stream for SequentialFutures +impl SequentialFutures<'_, S, F> +where + S: Stream + Send, + F: IntoFuture, +{ + pub fn new(active: NonZeroUsize, source: S) -> Self { + Self { + source: source.fuse(), + active: VecDeque::with_capacity(active.get()), + _marker: PhantomData, + } + } +} + +impl Stream for SequentialFutures<'_, S, F> where S: Stream + Send, F: IntoFuture, @@ -207,18 +121,9 @@ where } } -impl ExactSizeStream for SequentialFutures -where - S: Stream + Send + ExactSizeStream, - F: IntoFuture, -{ -} - #[cfg(all(test, unit_test))] -mod test { +mod local_test { use std::{ - convert::Infallible, - iter::once, num::NonZeroUsize, ptr::null, sync::{Arc, Mutex}, @@ -226,78 +131,13 @@ mod test { }; use futures::{ - future::{lazy, BoxFuture}, - stream::{iter, poll_fn, poll_immediate, repeat_with}, - Future, StreamExt, + future::lazy, + stream::{poll_fn, repeat_with}, + StreamExt, }; - use crate::seq_join::{seq_join, seq_try_join_all}; - - async fn immediate(count: u32) { - let capacity = NonZeroUsize::new(3).unwrap(); - let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) - .collect::>() - .await; - assert_eq!((0..count).collect::>(), values); - } - - #[tokio::test] - async fn within_capacity() { - immediate(2).await; - immediate(1).await; - } - - #[tokio::test] - async fn over_capacity() { - immediate(10).await; - } - - #[tokio::test] - async fn out_of_order() { - let capacity = NonZeroUsize::new(3).unwrap(); - let barrier = tokio::sync::Barrier::new(2); - let unresolved: BoxFuture<'_, u32> = Box::pin(async { - barrier.wait().await; - 0 - }); - let it = once(unresolved) - .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); - let mut seq_futures = seq_join(capacity, iter(it)); - - assert_eq!( - Some(Poll::Pending), - poll_immediate(&mut seq_futures).next().await - ); - barrier.wait().await; - assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); - } - - #[tokio::test] - async fn join_success() { - fn f(v: T) -> impl Future> { - lazy(move |_| Ok(v)) - } - - let active = NonZeroUsize::new(10).unwrap(); - let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); - assert_eq!((1..5).collect::>(), res); - } - - #[tokio::test] - async fn try_join_early_abort() { - const ERROR: &str = "error message"; - fn f(i: u32) -> impl Future> { - lazy(move |_| match i { - 1 => Ok(1), - 2 => Err(ERROR), - _ => panic!("should have aborted earlier"), - }) - } - - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); - } + use super::*; + use crate::test_executor::run; fn fake_waker() -> Waker { use std::task::{RawWaker, RawWakerVTable}; @@ -407,4 +247,22 @@ mod test { assert_count(&produced_r, 0); assert!(matches!(res, Poll::Ready(None))); } + + #[test] + fn try_join_early_abort() { + const ERROR: &str = "error message"; + fn f(i: u32) -> impl Future> { + lazy(move |_| match i { + 1 => Ok(1), + 2 => Err(ERROR), + _ => panic!("should have aborted earlier"), + }) + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + }); + } } diff --git a/ipa-core/src/seq_join/mod.rs b/ipa-core/src/seq_join/mod.rs new file mode 100644 index 000000000..dfe6b1073 --- /dev/null +++ b/ipa-core/src/seq_join/mod.rs @@ -0,0 +1,274 @@ +use std::{future::IntoFuture, num::NonZeroUsize}; + +use futures::{ + stream::{iter, Iter as StreamIter, TryCollect}, + Future, Stream, StreamExt, TryStreamExt, +}; +use pin_project::pin_project; + +use crate::exact::ExactSizeStream; + +#[cfg(not(feature = "multi-threading"))] +mod local; +#[cfg(feature = "multi-threading")] +mod multi_thread; + +/// This helper function might be necessary to convince the compiler that +/// the return value from [`seq_try_join_all`] implements `Send`. +/// Use this if you get higher-ranked lifetime errors that mention `std::marker::Send`. +/// +/// +pub fn assert_send<'a, O>( + fut: impl Future + Send + 'a, +) -> impl Future + Send + 'a { + fut +} + +/// Sequentially join futures from a stream. +/// +/// This function polls futures in strict sequence. +/// If any future blocks, up to `active - 1` futures after it will be polled so +/// that they make progress. +/// +/// # Deadlocks +/// +/// This will fail to resolve if the progress of any future depends on a future more +/// than `active` items behind it in the input sequence. +/// +/// # Safety +/// If multi-threading is enabled, forgetting the resulting future will cause use-after-free error. Do not leak it or +/// prevent the future destructor from running. +/// +/// [`try_join_all`]: futures::future::try_join_all +/// [`Stream`]: futures::stream::Stream +/// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered +pub fn seq_join<'st, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'st, S, F> +where + S: Stream + Send + 'st, + F: Future + Send, + O: Send + 'static, +{ + #[cfg(feature = "multi-threading")] + unsafe { + SequentialFutures::new(active, source) + } + #[cfg(not(feature = "multi-threading"))] + SequentialFutures::new(active, source) +} + +/// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter +/// from the provided context so that the value can be made consistent. +pub trait SeqJoin { + /// Perform a sequential join of the futures from the provided iterable. + /// This uses [`seq_join`], with the current state of the associated object + /// being used to determine the number of active items to track (see [`active_work`]). + /// + /// A rough rule of thumb for how to decide between this and [`parallel_join`] is + /// that this should be used whenever you are iterating over different records. + /// [`parallel_join`] is better suited to smaller batches, such as iterating over + /// the bits of a value for a single record. + /// + /// Note that the join functions from the [`futures`] crate, such as [`join3`], + /// are also parallel and can be used where you have a small, fixed number of tasks. + /// + /// Be especially careful if you use the random bits generator with this. + /// The random bits generator can produce values out of sequence. + /// You might need to use [`parallel_join`] for that. + /// + /// [`active_work`]: Self::active_work + /// [`parallel_join`]: Self::parallel_join + /// [`join3`]: futures::future::join3 + fn try_join<'fut, I, F, O, E>( + &self, + iterable: I, + ) -> TryCollect, Vec> + where + I: IntoIterator + Send, + I::IntoIter: Send + 'fut, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, + { + seq_try_join_all(self.active_work(), iterable) + } + + /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + /// + /// # Safety + /// Forgetting the future returned from this function will cause use-after-free. This is a tradeoff between + /// performance and safety that allows us to use regular references instead of Arc pointers. + /// + /// Dropping the future is always safe. + #[cfg(feature = "multi-threading")] + fn parallel_join<'a, I, F, O, E>( + &self, + iterable: I, + ) -> std::pin::Pin, E>> + Send + 'a>> + where + I: IntoIterator + Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static, + { + unsafe { Box::pin(multi_thread::parallel_join(iterable)) } + } + + /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + #[cfg(not(feature = "multi-threading"))] + fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll + where + I: IntoIterator, + I::Item: futures::future::TryFuture, + { + #[allow(clippy::disallowed_methods)] // Just in this one place. + futures::future::try_join_all(iterable) + } + + /// The amount of active work that is concurrently permitted. + fn active_work(&self) -> NonZeroUsize; +} + +type SeqTryJoinAll<'st, I, F> = + SequentialFutures<'st, StreamIter<::IntoIter>, F>; + +/// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. +/// This awaits all the provided futures in order, +/// aborting early if any future returns `Result::Err`. +pub fn seq_try_join_all<'iter, I, F, O, E>( + active: NonZeroUsize, + source: I, +) -> TryCollect, Vec> +where + I: IntoIterator + Send, + I::IntoIter: Send + 'iter, + F: Future> + Send + 'iter, + O: Send + 'static, + E: Send + 'static, +{ + seq_join(active, iter(source)).try_collect() +} + +impl<'fut, S, F> ExactSizeStream for SequentialFutures<'fut, S, F> +where + S: Stream + Send + ExactSizeStream, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, +{ +} + +#[cfg(not(feature = "multi-threading"))] +pub use local::SequentialFutures; +#[cfg(feature = "multi-threading")] +pub use multi_thread::SequentialFutures; + +#[cfg(all(test, any(unit_test, feature = "shuttle")))] +mod test { + use std::{convert::Infallible, iter::once, task::Poll}; + + use futures::{ + future::{lazy, BoxFuture}, + stream::{iter, poll_immediate}, + Future, StreamExt, + }; + + use super::*; + use crate::test_executor::run; + + async fn immediate(count: u32) { + let capacity = NonZeroUsize::new(3).unwrap(); + let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) + .collect::>() + .await; + assert_eq!((0..count).collect::>(), values); + } + + #[test] + fn within_capacity() { + run(|| async { + immediate(2).await; + immediate(1).await; + }); + } + + #[test] + fn over_capacity() { + run(|| async { + immediate(10).await; + }); + } + + #[test] + fn size() { + run(|| async { + let mut count = 10_usize; + let capacity = NonZeroUsize::new(3).unwrap(); + let mut values = seq_join(capacity, iter((0..count).map(|i| async move { i }))); + assert_eq!((count, Some(count)), values.size_hint()); + + while values.next().await.is_some() { + count -= 1; + assert_eq!((count, Some(count)), values.size_hint()); + } + }); + } + + #[test] + fn out_of_order() { + run(|| async { + let capacity = NonZeroUsize::new(3).unwrap(); + let barrier = tokio::sync::Barrier::new(2); + let unresolved: BoxFuture<'_, u32> = Box::pin(async { + barrier.wait().await; + 0 + }); + let it = once(unresolved) + .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); + let mut seq_futures = seq_join(capacity, iter(it)); + + assert_eq!( + Some(Poll::Pending), + poll_immediate(&mut seq_futures).next().await + ); + barrier.wait().await; + assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); + }); + } + + #[test] + fn join_success() { + fn f(v: T) -> impl Future> { + lazy(move |_| Ok(v)) + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); + assert_eq!((1..5).collect::>(), res); + }); + } + + #[test] + #[cfg_attr( + all(feature = "shuttle", feature = "multi-threading"), + should_panic(expected = "cancelled") + )] + fn does_not_block_on_error() { + const ERROR: &str = "returning early is safe"; + use std::pin::Pin; + + fn f(i: u32) -> Pin> + Send>> { + match i { + 1 => Box::pin(lazy(move |_| Ok(1))), + 2 => Box::pin(lazy(move |_| Err(ERROR))), + _ => Box::pin(futures::future::pending()), + } + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + }); + } +} diff --git a/ipa-core/src/seq_join/multi_thread.rs b/ipa-core/src/seq_join/multi_thread.rs new file mode 100644 index 000000000..79ee89d6e --- /dev/null +++ b/ipa-core/src/seq_join/multi_thread.rs @@ -0,0 +1,252 @@ +use std::{ + future::IntoFuture, + num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::stream::Fuse; +use tracing::{Instrument, Span}; + +use super::*; + +#[cfg(feature = "shuttle")] +mod shuttle_spawner { + use shuttle_crate::future::{self, JoinError, JoinHandle}; + + use super::*; + + /// Spawner implementation for Shuttle framework to run tests in parallel + pub(super) struct ShuttleSpawner; + + unsafe impl async_scoped::spawner::Spawner for ShuttleSpawner + where + T: Send + 'static, + { + type FutureOutput = Result; + type SpawnHandle = JoinHandle; + + fn spawn + Send + 'static>(&self, f: F) -> Self::SpawnHandle { + future::spawn(f) + } + } + + unsafe impl async_scoped::spawner::Blocker for ShuttleSpawner { + fn block_on>(&self, f: F) -> T { + future::block_on(f) + } + } +} + +#[cfg(feature = "shuttle")] +type Spawner<'fut, T> = async_scoped::Scope<'fut, T, shuttle_spawner::ShuttleSpawner>; +#[cfg(not(feature = "shuttle"))] +type Spawner<'fut, T> = async_scoped::TokioScope<'fut, T>; + +unsafe fn create_spawner<'fut, T: Send + 'static>() -> Spawner<'fut, T> { + #[cfg(feature = "shuttle")] + return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); + #[cfg(not(feature = "shuttle"))] + return async_scoped::TokioScope::create(async_scoped::spawner::use_tokio::Tokio); +} + +#[pin_project] +#[must_use = "Futures do nothing unless polled"] +pub struct SequentialFutures<'fut, S, F> +where + S: Stream + Send + 'fut, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, +{ + #[pin] + spawner: Spawner<'fut, F::Output>, + #[pin] + source: Fuse, + capacity: usize, +} + +impl SequentialFutures<'_, S, F> +where + S: Stream + Send, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, +{ + pub unsafe fn new(active: NonZeroUsize, source: S) -> Self { + SequentialFutures { + spawner: unsafe { create_spawner() }, + source: source.fuse(), + capacity: active.get(), + } + } +} + +impl<'fut, S, F> Stream for SequentialFutures<'fut, S, F> +where + S: Stream + Send, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, +{ + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. + while this.spawner.remaining() < *this.capacity { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + // Making futures cancellable is critical to avoid hangs. + // if one of them panics, unwinding causes spawner to drop and, in turn, + // it blocks the thread to await all pending futures completion. If there is + // a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is + // the behavior we want. + let task_index = this.spawner.len(); + this.spawner + .spawn_cancellable(f.into_future().instrument(Span::current()), move || { + panic!("SequentialFutures: spawned task {task_index} cancelled") + }); + } else { + break; + } + } + + // Poll spawner if it has work to do. If both source and spawner are empty, we're done. + if this.spawner.remaining() > 0 { + this.spawner.as_mut().poll_next(cx).map(|v| match v { + Some(Ok(v)) => Some(v), + Some(Err(_)) => panic!("SequentialFutures: spawned task aborted"), + None => None, + }) + } else if this.source.is_done() { + Poll::Ready(None) + } else { + Poll::Pending + } + } + + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.spawner.remaining(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } +} + +pub(super) unsafe fn parallel_join<'fut, I, F, O, E>( + iterable: I, +) -> impl Future, E>> + Send + 'fut +where + I: IntoIterator + Send, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, +{ + let mut scope = { + let mut scope = unsafe { create_spawner() }; + for element in iterable { + // it is important to make those cancellable to avoid deadlocks if one of the spawned future panics. + // If there is a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is the behavior we want. + scope.spawn_cancellable(element.instrument(Span::current()), || { + panic!("parallel_join: task cancelled") + }); + } + scope + }; + + async move { + let mut result = Vec::with_capacity(scope.len()); + while let Some(item) = scope.next().await { + // join error is nothing we can do about + result.push(item.expect("parallel_join: received JoinError")?) + } + Ok(result) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{future::Future, pin::Pin}; + + use crate::test_executor::run; + + /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause + /// use-after-free safety error. It spawns a few tasks that constantly try to access the `borrow_from_me` weak + /// reference while the main thread drops the owning reference. By proving that futures are able to see the weak + /// pointer unset, this test shows that same can happen for regular references and cause use-after-free. + #[test] + fn parallel_join_forget_is_not_safe() { + use futures::future::poll_immediate; + + use crate::{seq_join::multi_thread::parallel_join, sync::Arc}; + + run(|| async { + const N: usize = 24; + let borrowed_vec = Box::new([1, 2, 3]); + let borrow_from_me = Arc::new(vec![1, 2, 3]); + let start = Arc::new(tokio::sync::Barrier::new(N + 1)); + // counts how many tasks have accessed `borrow_from_me` after it was destroyed. + // this test expects all tasks to access `borrow_from_me` at least once. + let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1)); + + let futures = (0..N) + .map(|_| { + let borrowed = Arc::downgrade(&borrow_from_me); + let regular_ref = &borrowed_vec; + let start = start.clone(); + let bad_access = bad_accesses.clone(); + async move { + start.wait().await; + for _ in 0..100 { + if borrowed.upgrade().is_none() { + bad_access.wait().await; + // switch to `true` if you want to see the real corruption. + #[allow(unreachable_code)] + if false { + // this is a place where we can see the use-after-free. + // we avoid executing this block to appease sanitizers, but compiler happily + // allows us to follow this reference. + println!("{:?}", regular_ref); + } + break; + } + tokio::task::yield_now().await; + } + Ok::<_, ()>(()) + } + }) + .collect::>(); + + let mut f = Box::pin(unsafe { parallel_join(futures) }); + poll_immediate(&mut f).await; + start.wait().await; + + // the type of `f` above captures the lifetime for borrowed_vec. Leaking `f` allows `borrowed_vec` to be + // dropped, but that drop prohibits any subsequent manipulations with `f` pointer, irrespective of whether + // `f` is `&mut _` or `*mut _` (value already borrowed error). + // I am not sure I fully understand what is going on here (why borrowck allows me to leak the value, but + // then I can't drop it even if it is a raw pointer), but removing the lifetime from `f` type allows + // the test to pass. + // + // This is only required to do the proper cleanup and avoid memory leaks. Replacing this line with + // `mem::forget(f)` will lead to the same test outcome, but Miri will complain about memory leaks. + let f: _ = unsafe { + std::mem::transmute::<_, Pin, ()>> + Send>>>( + Box::pin(f) as Pin, ()>>>>, + ) + }; + + // Async executor will still be polling futures and they will try to follow this pointer. + drop(borrow_from_me); + drop(borrowed_vec); + + // this test should terminate because all tasks should access `borrow_from_me` at least once. + bad_accesses.wait().await; + + drop(f); + }); + } +}