diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index bb5771ca..c42b7656 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -103,6 +103,7 @@ impl AtomicExpiryTime { false } + #[allow(dead_code)] pub fn merge(&self, other: Self) { let mut other = other; loop { @@ -134,6 +135,10 @@ impl AtomicExpiryTime { } pub fn into_inner(self) -> SystemTime { + self.expires_at() + } + + pub fn expires_at(&self) -> SystemTime { SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst)) } } @@ -164,6 +169,12 @@ impl Clone for AtomicExpiringValue { } } +impl From for AtomicExpiryTime { + fn from(value: SystemTime) -> Self { + AtomicExpiryTime::new(value) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 33fa589d..704d8c85 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -1,7 +1,6 @@ use crate::storage::atomic_expiring_value::AtomicExpiryTime; use std::collections::btree_map::Entry; use std::collections::BTreeMap; -use std::ops::Not; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::{Duration, SystemTime}; @@ -13,6 +12,7 @@ struct CrCounterValue { expiry: AtomicExpiryTime, } +#[allow(dead_code)] impl CrCounterValue { pub fn new(actor: A, time_window: Duration) -> Self { Self { @@ -77,9 +77,12 @@ impl CrCounterValue { } pub fn merge_at(&self, other: Self, when: SystemTime) { - if self.expiry.expired_at(when).not() && other.expiry.expired_at(when).not() { - let (expiry, other_values) = other.into_inner(); - let _ = self.expiry.merge_at(AtomicExpiryTime::new(expiry), when); + let (expiry, other_values) = other.into_inner(); + if expiry > when { + let _ = self.expiry.merge_at(expiry.into(), when); + if self.expiry.expired_at(when) { + self.reset(expiry); + } let ourselves = self.value.load(Ordering::SeqCst); let mut others = self.others.write().unwrap(); for (actor, other_value) in other_values { @@ -116,6 +119,13 @@ impl CrCounterValue { map.insert(ourselves, value.into_inner()); (expiry.into_inner(), map) } + + fn reset(&self, expiry: SystemTime) { + let mut guard = self.others.write().unwrap(); + self.expiry.update(expiry); + self.value.store(0, Ordering::SeqCst); + guard.clear() + } } #[cfg(test)] @@ -168,42 +178,80 @@ mod tests { #[test] fn merges() { let window = Duration::from_secs(1); - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - a.merge(b); - assert_eq!(a.read(), 5); - } + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 5); + } - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - b.merge(a); - assert_eq!(b.read(), 5); - } + #[test] + fn merges_symetric() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 5); + } - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - b.inc_actor('A', 2, window); // older value! - b.merge(a); // merges the 3 - assert_eq!(b.read(), 5); - } + #[test] + fn merges_overrides_with_larger_value() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 2, window); // older value! + b.merge(a); // merges the 3 + assert_eq!(b.read(), 5); + } - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - b.inc_actor('A', 5, window); // newer value! - b.merge(a); // ignores the 3 and keeps its own 5 for a - assert_eq!(b.read(), 7); - } + #[test] + fn merges_ignore_lesser_values() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 5, window); // newer value! + b.merge(a); // ignores the 3 and keeps its own 5 for a + assert_eq!(b.read(), 7); + } + + #[test] + fn merge_ignores_expired_sets() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', Duration::ZERO); + a.inc(3, Duration::ZERO); + let b = CrCounterValue::new('B', window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 2); + } + + #[test] + fn merge_ignores_expired_sets_symmetric() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', Duration::ZERO); + a.inc(3, Duration::ZERO); + let b = CrCounterValue::new('B', window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 2); + } + + #[test] + fn merge_uses_earliest_expiry() { + let later = Duration::from_secs(1); + let a = CrCounterValue::new('A', later); + let sooner = Duration::from_millis(200); + let b = CrCounterValue::new('B', sooner); + a.inc(3, later); + b.inc(2, later); + a.merge(b); + assert!(a.expiry.duration() < sooner); } }