diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 547817339..50e2d98d6 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -74,6 +74,25 @@ pub(crate) mod task { pub use shuttle::future::{JoinError, JoinHandle}; } +#[cfg(feature = "shuttle")] +pub(crate) mod shim { + use std::any::Any; + + use shuttle_crate::future::JoinError; + + /// There is currently an API mismatch between Tokio and Shuttle `JoinError` implementations. + /// This trait brings them closer together, until it is addressed + pub trait Tokio: Sized { + fn try_into_panic(self) -> Result, Self>; + } + + impl Tokio for JoinError { + fn try_into_panic(self) -> Result, Self> { + Err(self) // Shuttle `JoinError` does not wrap panics + } + } +} + #[cfg(not(all(feature = "shuttle", test)))] pub(crate) mod task { pub use tokio::task::{JoinError, JoinHandle}; diff --git a/ipa-core/src/seq_join/multi_thread.rs b/ipa-core/src/seq_join/multi_thread.rs index 0022c7bc3..492dcae9f 100644 --- a/ipa-core/src/seq_join/multi_thread.rs +++ b/ipa-core/src/seq_join/multi_thread.rs @@ -89,6 +89,9 @@ where type Item = F::Output; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[cfg(feature = "shuttle")] + use crate::shim::Tokio; + let mut this = self.project(); // Draw more values from the input, up to the capacity. @@ -114,7 +117,13 @@ where if this.spawner.remaining() > 0 { this.spawner.as_mut().poll_next(cx).map(|v| match v { Some(Ok(v)) => Some(v), - Some(Err(_)) => panic!("SequentialFutures: spawned task aborted"), + Some(Err(e)) => { + if let Ok(reason) = e.try_into_panic() { + std::panic::resume_unwind(reason); + } else { + panic!("SequentialFutures: spawned task is cancelled") + } + } None => None, }) } else if this.source.is_done() { @@ -168,9 +177,11 @@ where #[cfg(all(test, unit_test))] mod tests { - use std::{future::Future, pin::Pin}; + use std::{future::Future, num::NonZeroUsize, pin::Pin}; + + use futures_util::future::lazy; - use crate::test_executor::run; + use crate::{seq_join::seq_try_join_all, test_executor::run}; /// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause /// use-after-free safety error. It spawns a few tasks that constantly try to access the `borrow_from_me` weak @@ -248,4 +259,21 @@ mod tests { drop(f); }); } + + #[test] + #[should_panic(expected = "panic in task 1")] + fn panic_from_task_unwinds_to_main() { + fn f(i: u32) -> impl Future> { + lazy(move |_| match i { + 1 => panic!("panic in task 1"), + i => Ok(i), + }) + } + + run(|| async { + let active = NonZeroUsize::new(10).unwrap(); + let _ = seq_try_join_all(active, (1..=3).map(f)).await; + assert!(false, "Should have aborted earlier"); + }); + } }