diff --git a/Cargo.toml b/Cargo.toml index 52244e8..795d80d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,11 @@ exclude = ["/.github/*", "/.travis.yml", "/appveyor.yml"] log = { version = "0.4", features = ["std"] } once_cell = "1.9.0" rand = "0.8" +tokio = { version = "1.40.0", features = ["sync", "time", "rt", "macros", "test-util"], optional = true } [features] failpoints = [] +async = ["dep:tokio"] [package.metadata.docs.rs] all-features = true diff --git a/src/lib.rs b/src/lib.rs index 2c5305b..a6caa25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -230,7 +230,7 @@ use std::env::VarError; use std::fmt::Debug; use std::str::FromStr; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, TryLockError}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock}; use std::time::{Duration, Instant}; use std::{env, thread}; @@ -427,65 +427,93 @@ impl FromStr for Action { } } -#[allow(clippy::mutex_atomic)] #[derive(Debug)] struct FailPoint { - pause: Mutex, - pause_notifier: Condvar, - actions: RwLock>, - actions_str: RwLock, + actions: Mutex, + sync_notifier: Condvar, + #[cfg(feature = "async")] + async_notifier: AsyncNotifier +} + +#[derive(Debug)] +#[cfg(feature = "async")] +struct AsyncNotifier { + tx: tokio::sync::watch::Sender, + rx: tokio::sync::watch::Receiver, +} + +#[derive(Debug)] +struct ConfiguredActions { + seq: u64, + actions_str: String, + actions: Vec +} + +impl ConfiguredActions { + fn empty(seq: u64) -> ConfiguredActions { + ConfiguredActions { + seq, + actions_str: String::new(), + actions: vec![] + } + } +} + +#[cfg(feature = "async")] +impl AsyncNotifier { + fn new() -> AsyncNotifier { + let (tx, rx) = tokio::sync::watch::channel(0); + AsyncNotifier { tx, rx } + } } -#[allow(clippy::mutex_atomic)] impl FailPoint { fn new() -> FailPoint { + let initial_seq: u64 = 0; + let initial_actions = ConfiguredActions::empty(initial_seq); + FailPoint { - pause: Mutex::new(false), - pause_notifier: Condvar::new(), - actions: RwLock::default(), - actions_str: RwLock::default(), + actions: Mutex::new(initial_actions), + sync_notifier: Condvar::new(), + #[cfg(feature = "async")] + async_notifier: AsyncNotifier::new(), } } + fn actions_str(&self) -> String { + let actions_guard = self.actions.lock().unwrap(); + (*actions_guard).actions_str.clone() + } + fn set_actions(&self, actions_str: &str, actions: Vec) { - loop { - // TODO: maybe busy waiting here. - match self.actions.try_write() { - Err(TryLockError::WouldBlock) => {} - Ok(mut guard) => { - *guard = actions; - *self.actions_str.write().unwrap() = actions_str.to_string(); - return; - } - Err(e) => panic!("unexpected poison: {:?}", e), - } - let mut guard = self.pause.lock().unwrap(); - *guard = false; - self.pause_notifier.notify_all(); - } + let mut actions_guard = self.actions.lock().unwrap(); + let next_seq = (*actions_guard).seq + 1; + *actions_guard = ConfiguredActions { + seq: next_seq, + actions_str: actions_str.to_string(), + actions + }; + self.sync_notifier.notify_all(); + #[cfg(feature = "async")] + self.async_notifier.tx.send(next_seq).unwrap(); } #[allow(clippy::option_option)] fn eval(&self, name: &str) -> Option> { - let task = { - let actions = self.actions.read().unwrap(); - match actions.iter().filter_map(Action::get_task).next() { - Some(Task::Pause) => { - let mut guard = self.pause.lock().unwrap(); - *guard = true; - loop { - guard = self.pause_notifier.wait(guard).unwrap(); - if !*guard { - break; - } - } - return None; - } - Some(t) => t, - None => return None, - } - }; + let (task_opt, action_seq) = self.next_task(); + if let Some(task) = task_opt { + self.eval_task(action_seq, name, task) + } else { + None + } + } + fn eval_task( + &self, + action_seq: u64, + name: &str, + task: Task + ) -> Option> { match task { Task::Off => {} Task::Return(s) => return Some(s), @@ -498,7 +526,12 @@ impl FailPoint { Some(ref msg) => log::info!("{}", msg), None => log::info!("failpoint {} executed.", name), }, - Task::Pause => unreachable!(), + Task::Pause => { + let _unused = self.sync_notifier.wait_while( + self.actions.lock().unwrap(), + |guard| { (*guard).seq == action_seq } + ).unwrap(); + }, Task::Yield => thread::yield_now(), Task::Delay(t) => { let timer = Instant::now(); @@ -511,6 +544,61 @@ impl FailPoint { } None } + + #[allow(dead_code)] + #[allow(clippy::option_option)] + #[cfg(feature = "async")] + async fn eval_async(&self, name: &str) -> Option> { + let (task_opt, action_seq) = self.next_task(); + if let Some(task) = task_opt { + self.eval_task_async(action_seq, name, task).await + } else { + None + } + } + + fn next_task(&self) -> (Option, u64){ + let guard = self.actions.lock().unwrap(); + let task = guard.actions.iter().filter_map(Action::get_task).next(); + (task, (*guard).seq) + } + + #[cfg(feature = "async")] + async fn eval_task_async( + &self, + action_seq: u64, + name: &str, + task: Task + ) -> Option> { + match task { + Task::Off => {} + Task::Return(s) => return Some(s), + Task::Sleep(t) => + tokio::time::sleep(Duration::from_millis(t)).await, + Task::Panic(msg) => match msg { + Some(ref msg) => panic!("{}", msg), + None => panic!("failpoint {} panic", name), + }, + Task::Print(msg) => match msg { + Some(ref msg) => log::info!("{}", msg), + None => log::info!("failpoint {} executed.", name), + }, + Task::Pause => { + let mut rx = self.async_notifier.rx.clone(); + rx.wait_for(|val| *val != action_seq).await.unwrap(); + }, + Task::Yield => tokio::task::yield_now().await, + Task::Delay(t) => { + let timer = Instant::now(); + let timeout = Duration::from_millis(t); + while timer.elapsed() < timeout {} + } + Task::Callback(f) => { + f.run(); + } + } + None + } } /// Registry with failpoints configuration. @@ -626,20 +714,40 @@ pub fn list() -> Vec<(String, String)> { let registry = REGISTRY.registry.read().unwrap(); registry .iter() - .map(|(name, fp)| (name.to_string(), fp.actions_str.read().unwrap().clone())) + .map(|(name, fp)| (name.to_string(), fp.actions_str())) .collect() } +fn find_fail_point( + name: &str, +) -> Option> { + let registry = REGISTRY.registry.read().unwrap(); + registry.get(name).map(|p| p.clone()) +} + #[doc(hidden)] -pub fn eval) -> R>(name: &str, f: F) -> Option { - let p = { - let registry = REGISTRY.registry.read().unwrap(); - match registry.get(name) { - None => return None, - Some(p) => p.clone(), - } - }; - p.eval(name).map(f) +pub fn eval) -> R>( + name: &str, + f: F, +) -> Option { + if let Some(p) = find_fail_point(name) { + p.eval(name).map(f) + } else { + None + } +} + +#[doc(hidden)] +#[cfg(feature = "async")] +pub async fn eval_async) -> R>( + name: &str, + f: F, +) -> Option { + if let Some(p) = find_fail_point(name) { + p.eval_async(name).await.map(f) + } else { + None + } } /// Configure the actions for a fail point at runtime. @@ -741,8 +849,7 @@ impl FailGuard { } } -fn set( - registry: &mut HashMap>, +fn set(registry: &mut HashMap>, name: String, actions: &str, ) -> Result<(), String> { @@ -842,6 +949,90 @@ macro_rules! fail_point { }}; } +/// Define an async fail point (requires `failpoints` and `async` features). +/// +/// The `fail_point_async!` macro is similar to `fail_point` except that it +/// can be safely used in an async function. Similar to `fail_point`, it +/// has three forms, and they all take a name as the +/// first argument. The simplest form takes only a name and is suitable for +/// executing most fail point behavior, including panicking, but not for early +/// return or conditional execution based on a local flag. +/// +/// The three forms of fail points look as follows. +/// +/// 1. A basic fail point: +/// +/// ```rust, ignore +/// # #[macro_use] extern crate fail; +/// async fn function_return_unit() { +/// fail_point_async!("fail-point-1"); +/// } +/// ``` +/// +/// This form of fail point can be configured to panic, print, sleep, pause, etc., but +/// not to return from the function early. +/// +/// 2. A fail point that may return early: +/// +/// ```rust, ignore +/// # #[macro_use] extern crate fail; +/// async fn function_return_value() -> u64 { +/// fail_point_async!("fail-point-2", |r| r.map_or(2, |e| e.parse().unwrap())); +/// 0 +/// } +/// ``` +/// +/// This form of fail point can additionally be configured to return early from +/// the enclosing function. It accepts a closure, which itself accepts an +/// `Option`, and is expected to transform that argument into the early +/// return value. The argument string is sourced from the fail point +/// configuration string. For example configuring this "fail-point-2" as +/// "return(100)" will execute the fail point closure, passing it a `Some` value +/// containing a `String` equal to "100"; the closure then parses it into the +/// return value. +/// +/// 3. A fail point with conditional execution: +/// +/// ```rust, ignore +/// # #[macro_use] extern crate fail; +/// async fn function_conditional(enable: bool) { +/// fail_point_async!("fail-point-3", enable, |_| {}); +/// } +/// ``` +/// +/// In this final form, the second argument is a local boolean expression that +/// must evaluate to `true` before the fail point is evaluated. The third +/// argument is again an early-return closure. +/// +/// The three macro arguments (or "designators") are called `$name`, `$cond`, +/// and `$e`. `$name` must be `&str`, `$cond` must be a boolean expression, +/// and`$e` must be a function or closure that accepts an `Option` and +/// returns the same type as the enclosing function. +/// +/// For more examples see the [crate documentation](index.html). For more +/// information about controlling fail points see the [`cfg`](fn.cfg.html) +/// function. +#[macro_export] +#[cfg(all(feature = "failpoints", feature = "async"))] +macro_rules! fail_point_async { + ($registry:expr, $name:expr) => {{ + $crate::eval_async($registry, $name, |_| { + panic!("Return is not supported for the fail point \"{}\"", $name); + }).await; + }}; + ($registry:expr, $name:expr, $e:expr) => {{ + if let Some(res) = $crate::eval_async($registry, $name, $e).await { + return res; + } + }}; + ($registry:expr, $name:expr, $cond:expr, $e:expr) => {{ + if $cond { + $crate::fail_point_async!($registry, $name, $e); + } + }}; +} + + /// Define a fail point (disabled, see `failpoints` feature). #[macro_export] #[cfg(not(feature = "failpoints"))] @@ -940,6 +1131,54 @@ mod tests { rx.recv_timeout(Duration::from_secs(1)).unwrap(); } + + #[cfg(feature = "async")] + #[tokio::test] + async fn test_async_pause() { + let point = Arc::new(FailPoint::new()); + point.set_actions("", vec![Action::new(Task::Pause, 1.0, None)]); + let p = point.clone(); + let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let handle = tokio::spawn(async move { + assert_eq!(p.eval_async("test_fail_point_pause").await, None); + tx.send(()).await.unwrap() + }); + assert!(rx.try_recv().is_err()); + point.set_actions("", vec![Action::new(Task::Off, 1.0, None)]); + rx.recv().await.unwrap(); + assert!(tokio::join!(handle).0.is_ok()); + } + + #[cfg(feature = "async")] + #[tokio::test(flavor="current_thread", start_paused=true)] + async fn test_async_sleep() { + let value = Arc::new(atomic::AtomicU64::new(0)); + + fn spawn_sleep_task( + sleep_duration_millis: u64, + value: Arc, + value_to_set: u64 + ) -> tokio::task::JoinHandle<()> { + let point = Arc::new(FailPoint::new()); + point.set_actions("", vec![Action::new( + Task::Sleep(sleep_duration_millis), 1.0, None) + ]); + let p = point.clone(); + tokio::spawn(async move { + assert_eq!(p.eval_async("test_fail_point_sleep").await, None); + value.store(value_to_set, Ordering::Relaxed); + }) + } + + let h1 = spawn_sleep_task(10, value.clone(), 10); + let h2 = spawn_sleep_task(5, value.clone(), 5); + + assert!(tokio::join!(h2).0.is_ok()); + assert_eq!(value.load(Ordering::Relaxed), 5); + assert!(tokio::join!(h1).0.is_ok()); + assert_eq!(value.load(Ordering::Relaxed), 10); + } + #[test] fn test_yield() { let point = FailPoint::new();