diff --git a/chitchat/Cargo.toml b/chitchat/Cargo.toml index f486241..453454e 100644 --- a/chitchat/Cargo.toml +++ b/chitchat/Cargo.toml @@ -29,9 +29,16 @@ zstd = "0.13" [dev-dependencies] assert-json-diff = "2" -mock_instant = "0.3" tracing-subscriber = "0.3" proptest = "1.4" +tokio = { version = "1.28.0", features = [ + "net", + "sync", + "rt-multi-thread", + "macros", + "test-util", + "time", +] } [features] testsuite = [] diff --git a/chitchat/src/failure_detector.rs b/chitchat/src/failure_detector.rs index a783955..d8b8038 100644 --- a/chitchat/src/failure_detector.rs +++ b/chitchat/src/failure_detector.rs @@ -1,11 +1,8 @@ use std::collections::{HashMap, HashSet}; use std::time::Duration; -#[cfg(not(test))] -use std::time::Instant; -#[cfg(test)] -use mock_instant::Instant; use serde::{Deserialize, Serialize}; +use tokio::time::Instant; use tracing::debug; use crate::ChitchatId; @@ -287,15 +284,15 @@ impl BoundedArrayStats { mod tests { use std::time::Duration; - use mock_instant::MockClock; use rand::prelude::*; use super::{BoundedArrayStats, SamplingWindow}; use crate::failure_detector::{FailureDetector, FailureDetectorConfig}; use crate::ChitchatId; - #[test] - fn test_failure_detector() { + #[tokio::test] + async fn test_failure_detector() { + tokio::time::pause(); let mut rng = rand::thread_rng(); let mut failure_detector = FailureDetector::new(FailureDetectorConfig::default()); @@ -308,7 +305,7 @@ mod tests { for _ in 0..=2000 { let time_offset = intervals_choices.choose(&mut rng).unwrap(); let chitchat_id = chitchat_ids_choices.choose(&mut rng).unwrap(); - MockClock::advance(Duration::from_secs(*time_offset)); + tokio::time::advance(Duration::from_secs(*time_offset)).await; failure_detector.report_heartbeat(chitchat_id); } @@ -325,7 +322,7 @@ mod tests { assert_eq!(failure_detector.garbage_collect(), Vec::new()); // stop reporting heartbeat for few seconds - MockClock::advance(Duration::from_secs(50)); + tokio::time::advance(Duration::from_secs(50)).await; for chitchat_id in &chitchat_ids_choices { failure_detector.update_node_liveness(chitchat_id); } @@ -338,7 +335,7 @@ mod tests { assert_eq!(failure_detector.garbage_collect(), Vec::new()); // Wait for dead_node_grace_period & garbage collect. - MockClock::advance(Duration::from_secs(25 * 60 * 60)); + tokio::time::advance(Duration::from_secs(25 * 60 * 60)).await; let garbage_collected_nodes = failure_detector.garbage_collect(); assert_eq!( failure_detector @@ -365,8 +362,9 @@ mod tests { ); } - #[test] - fn test_failure_detector_node_state_from_live_to_down_to_live() { + #[tokio::test] + async fn test_failure_detector_node_state_from_live_to_down_to_live() { + tokio::time::pause(); let mut rng = rand::thread_rng(); let mut failure_detector = FailureDetector::new(FailureDetectorConfig::default()); let intervals_choices = [1u64, 2]; @@ -374,7 +372,7 @@ mod tests { for _ in 0..=2000 { let time_offset = intervals_choices.choose(&mut rng).unwrap(); - MockClock::advance(Duration::from_secs(*time_offset)); + tokio::time::advance(Duration::from_secs(*time_offset)).await; failure_detector.report_heartbeat(&node_1); } @@ -388,7 +386,7 @@ mod tests { ); // Check node-1 is down (stop reporting heartbeat). - MockClock::advance(Duration::from_secs(20)); + tokio::time::advance(Duration::from_secs(20)).await; failure_detector.update_node_liveness(&node_1); assert_eq!( failure_detector @@ -401,7 +399,7 @@ mod tests { // Check node-1 is back up (resume reporting heartbeat). for _ in 0..=500 { let time_offset = intervals_choices.choose(&mut rng).unwrap(); - MockClock::advance(Duration::from_secs(*time_offset)); + tokio::time::advance(Duration::from_secs(*time_offset)).await; failure_detector.report_heartbeat(&node_1); } failure_detector.update_node_liveness(&node_1); @@ -414,14 +412,15 @@ mod tests { ); } - #[test] - fn test_failure_detector_node_state_after_initial_interval() { + #[tokio::test] + async fn test_failure_detector_node_state_after_initial_interval() { + tokio::time::pause(); let mut failure_detector = FailureDetector::new(FailureDetectorConfig::default()); let chitchat_id = ChitchatId::for_local_test(10_001); failure_detector.report_heartbeat(&chitchat_id); - MockClock::advance(Duration::from_secs(1)); + tokio::time::advance(Duration::from_secs(1)).await; failure_detector.update_node_liveness(&chitchat_id); let live_nodes = failure_detector @@ -429,7 +428,7 @@ mod tests { .map(|chitchat_id| chitchat_id.node_id.as_str()) .collect::>(); assert_eq!(live_nodes, vec!["node-10001"]); - MockClock::advance(Duration::from_secs(40)); + tokio::time::advance(Duration::from_secs(40)).await; failure_detector.update_node_liveness(&chitchat_id); let live_nodes = failure_detector @@ -439,13 +438,14 @@ mod tests { assert_eq!(live_nodes, Vec::<&str>::new()); } - #[test] - fn test_sampling_window() { + #[tokio::test] + async fn test_sampling_window() { + tokio::time::pause(); let mut sampling_window = SamplingWindow::new(10, Duration::from_secs(5), Duration::from_secs(2)); sampling_window.report_heartbeat(); - MockClock::advance(Duration::from_secs(3)); + tokio::time::advance(Duration::from_secs(3)).await; sampling_window.report_heartbeat(); // Now intervals window is: [2.0, 3.0]. @@ -455,13 +455,13 @@ mod tests { assert!((sampling_window.phi() - (0.0 / mean)).abs() < f64::EPSILON); // 1s elapsed since last reported heartbeat. - MockClock::advance(Duration::from_secs(1)); + tokio::time::advance(Duration::from_secs(1)).await; assert!((sampling_window.phi() - (1.0 / mean)).abs() < f64::EPSILON); // Check reported heartbeat later than max_interval is ignore. - MockClock::advance(Duration::from_secs(5)); + tokio::time::advance(Duration::from_secs(5)).await; sampling_window.report_heartbeat(); - MockClock::advance(Duration::from_secs(2)); + tokio::time::advance(Duration::from_secs(2)).await; assert!( (sampling_window.phi() - (2.0 / mean)).abs() < f64::EPSILON, "Mean value should not change." diff --git a/chitchat/src/lib.rs b/chitchat/src/lib.rs index c31333b..adda914 100644 --- a/chitchat/src/lib.rs +++ b/chitchat/src/lib.rs @@ -350,10 +350,8 @@ mod tests { use std::sync::Arc; use std::time::Duration; - use mock_instant::MockClock; use tokio::sync::Mutex; use tokio::time; - use tokio_stream::wrappers::IntervalStream; use tokio_stream::StreamExt; use super::*; @@ -427,13 +425,6 @@ mod tests { .collect::>(); chitchat_handlers.push(start_node(chitchat_id.clone(), &seeds, transport).await); } - // Make sure the failure detector's fake clock moves forward. - tokio::spawn(async { - let mut ticker = IntervalStream::new(time::interval(Duration::from_millis(50))); - while ticker.next().await.is_some() { - MockClock::advance(Duration::from_millis(50)); - } - }); chitchat_handlers } diff --git a/chitchat/src/state.rs b/chitchat/src/state.rs index d6b78da..17cf7dc 100644 --- a/chitchat/src/state.rs +++ b/chitchat/src/state.rs @@ -3,7 +3,6 @@ use std::collections::{BTreeMap, HashSet}; use std::fmt::{Debug, Formatter}; use std::net::{Ipv4Addr, SocketAddr}; use std::ops::Bound; -use std::time::Instant; use itertools::Itertools; use rand::prelude::SliceRandom; @@ -24,9 +23,6 @@ pub struct NodeState { key_values: BTreeMap, max_version: Version, #[serde(skip)] - #[serde(default = "Instant::now")] - last_heartbeat: Instant, - #[serde(skip)] listeners: Listeners, } @@ -36,7 +32,6 @@ impl Debug for NodeState { .field("heartbeat", &self.heartbeat) .field("key_values", &self.key_values) .field("max_version", &self.max_version) - .field("last_heartbeat", &self.last_heartbeat) .finish() } } @@ -48,7 +43,6 @@ impl NodeState { heartbeat: Heartbeat(0), key_values: Default::default(), max_version: Default::default(), - last_heartbeat: Instant::now(), listeners, } } @@ -63,7 +57,6 @@ impl NodeState { heartbeat: Heartbeat(0), key_values: Default::default(), max_version: Default::default(), - last_heartbeat: Instant::now(), listeners: Listeners::default(), } } @@ -308,7 +301,6 @@ impl ClusterState { .or_insert_with(|| NodeState::new(chitchat_id, self.listeners.clone())); if node_state.heartbeat < heartbeat { node_state.heartbeat = heartbeat; - node_state.last_heartbeat = Instant::now(); } for (key, versioned_value) in key_values { node_state.max_version = node_state.max_version.max(versioned_value.version);