diff --git a/Cargo.lock b/Cargo.lock index 0fde7c82..da5b99a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1570,7 +1570,6 @@ dependencies = [ "thiserror", "tokio", "tracing", - "ttl_cache", ] [[package]] @@ -1609,12 +1608,6 @@ dependencies = [ "url", ] -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.4.11" @@ -3433,15 +3426,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" -[[package]] -name = "ttl_cache" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a" -dependencies = [ - "linked-hash-map", -] - [[package]] name = "typenum" version = "1.17.0" diff --git a/limitador-server/examples/limits.yaml b/limitador-server/examples/limits.yaml index 4178b413..f0ea815b 100644 --- a/limitador-server/examples/limits.yaml +++ b/limitador-server/examples/limits.yaml @@ -14,4 +14,4 @@ conditions: - "req.method == 'POST'" variables: - - user_id \ No newline at end of file + - user_id diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index c2591da4..380864e1 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -23,7 +23,6 @@ lenient_conditions = [] [dependencies] moka = "0.11.2" getrandom = { version = "0.2", features = ["js"] } -ttl_cache = "0.5" serde = { version = "1", features = ["derive"] } postcard = { version = "1.0.4", features = ["use-std"] } serde_json = "1" diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index 634b9e35..e13b7c69 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -5,22 +5,19 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[derive(Debug)] pub(crate) struct AtomicExpiringValue { value: AtomicI64, - expiry: AtomicU64, // in microseconds + expiry: AtomicExpiryTime, } impl AtomicExpiringValue { pub fn new(value: i64, expiry: SystemTime) -> Self { - let expiry = Self::get_duration_micros(expiry); Self { value: AtomicI64::new(value), - expiry: AtomicU64::new(expiry), + expiry: AtomicExpiryTime::new(expiry), } } pub fn value_at(&self, when: SystemTime) -> i64 { - let when = Self::get_duration_micros(when); - let expiry = self.expiry.load(Ordering::SeqCst); - if expiry <= when { + if self.expiry.expired_at(when) { return 0; } self.value.load(Ordering::SeqCst) @@ -31,25 +28,49 @@ impl AtomicExpiringValue { } pub fn update(&self, delta: i64, ttl: u64, when: SystemTime) -> i64 { - let ttl_micros = ttl * 1_000_000; - let when_micros = Self::get_duration_micros(when); - - let expiry = self.expiry.load(Ordering::SeqCst); - if expiry <= when_micros { - let new_expiry = when_micros + ttl_micros; - if self - .expiry - .compare_exchange(expiry, new_expiry, Ordering::SeqCst, Ordering::SeqCst) - .is_ok() - { - self.value.store(delta, Ordering::SeqCst); - } + if self.expiry.update_if_expired(ttl, when) { + self.value.store(delta, Ordering::SeqCst); return delta; } self.value.fetch_add(delta, Ordering::SeqCst) + delta } pub fn ttl(&self) -> Duration { + self.expiry.duration() + } + + #[allow(dead_code)] + pub fn set(&self, value: i64, ttl: Duration) { + self.expiry.update(ttl); + self.value.store(value, Ordering::SeqCst); + } +} + +#[derive(Debug)] +pub struct AtomicExpiryTime { + expiry: AtomicU64, // in microseconds +} + +impl AtomicExpiryTime { + pub fn new(when: SystemTime) -> Self { + let expiry = Self::since_epoch(when); + Self { + expiry: AtomicU64::new(expiry), + } + } + + #[allow(dead_code)] + pub fn from_now(ttl: Duration) -> Self { + Self::new(SystemTime::now() + ttl) + } + + fn since_epoch(when: SystemTime) -> u64 { + when.duration_since(UNIX_EPOCH) + .expect("SystemTime before UNIX EPOCH!") + .as_micros() as u64 + } + + pub fn duration(&self) -> Duration { let expiry = SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst)); expiry @@ -57,10 +78,37 @@ impl AtomicExpiringValue { .unwrap_or(Duration::ZERO) } - fn get_duration_micros(when: SystemTime) -> u64 { - when.duration_since(UNIX_EPOCH) - .expect("SystemTime before UNIX EPOCH!") - .as_micros() as u64 + pub fn expired_at(&self, when: SystemTime) -> bool { + let when = Self::since_epoch(when); + self.expiry.load(Ordering::SeqCst) <= when + } + + #[allow(dead_code)] + pub fn update(&self, ttl: Duration) { + self.expiry + .store(Self::since_epoch(SystemTime::now() + ttl), Ordering::SeqCst); + } + + pub fn update_if_expired(&self, ttl: u64, when: SystemTime) -> bool { + let ttl_micros = ttl * 1_000_000; + let when_micros = Self::since_epoch(when); + let expiry = self.expiry.load(Ordering::SeqCst); + if expiry <= when_micros { + let new_expiry = when_micros + ttl_micros; + return self + .expiry + .compare_exchange(expiry, new_expiry, Ordering::SeqCst, Ordering::SeqCst) + .is_ok(); + } + false + } +} + +impl Clone for AtomicExpiryTime { + fn clone(&self) -> Self { + Self { + expiry: AtomicU64::new(self.expiry.load(Ordering::SeqCst)), + } } } @@ -68,7 +116,7 @@ impl Default for AtomicExpiringValue { fn default() -> Self { AtomicExpiringValue { value: AtomicI64::new(0), - expiry: AtomicU64::new(0), + expiry: AtomicExpiryTime::new(UNIX_EPOCH), } } } @@ -77,7 +125,7 @@ impl Clone for AtomicExpiringValue { fn clone(&self) -> Self { AtomicExpiringValue { value: AtomicI64::new(self.value.load(Ordering::SeqCst)), - expiry: AtomicU64::new(self.expiry.load(Ordering::SeqCst)), + expiry: self.expiry.clone(), } } } diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 01f8a146..603fc3ac 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -1,15 +1,63 @@ use crate::counter::Counter; +use crate::storage::atomic_expiring_value::{AtomicExpiringValue, AtomicExpiryTime}; use crate::storage::redis::{ DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC, DEFAULT_TTL_RATIO_CACHED_COUNTERS, }; -use std::time::Duration; -use ttl_cache::TtlCache; +use moka::sync::Cache; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; + +pub struct CachedCounterValue { + value: AtomicExpiringValue, + expiry: AtomicExpiryTime, +} pub struct CountersCache { max_ttl_cached_counters: Duration, pub ttl_ratio_cached_counters: u64, - cache: TtlCache, + cache: Cache>, +} + +impl CachedCounterValue { + pub fn from(counter: &Counter, value: i64, ttl: Duration) -> Self { + let now = SystemTime::now(); + Self { + value: AtomicExpiringValue::new(value, now + Duration::from_secs(counter.seconds())), + expiry: AtomicExpiryTime::from_now(ttl), + } + } + + pub fn expired_at(&self, now: SystemTime) -> bool { + self.expiry.expired_at(now) + } + + pub fn set_from_authority(&self, counter: &Counter, value: i64, expiry: Duration) { + let time_window = Duration::from_secs(counter.seconds()); + self.value.set(value, time_window); + self.expiry.update(expiry); + } + + pub fn delta(&self, counter: &Counter, delta: i64) -> i64 { + self.value + .update(delta, counter.seconds(), SystemTime::now()) + } + + pub fn hits(&self, _: &Counter) -> i64 { + self.value.value_at(SystemTime::now()) + } + + pub fn remaining(&self, counter: &Counter) -> i64 { + counter.max_value() - self.hits(counter) + } + + pub fn is_limited(&self, counter: &Counter, delta: i64) -> bool { + self.hits(counter) as i128 + delta as i128 > counter.max_value() as i128 + } + + pub fn to_next_window(&self) -> Duration { + self.value.ttl() + } } pub struct CountersCacheBuilder { @@ -46,40 +94,51 @@ impl CountersCacheBuilder { CountersCache { max_ttl_cached_counters: self.max_ttl_cached_counters, ttl_ratio_cached_counters: self.ttl_ratio_cached_counters, - cache: TtlCache::new(self.max_cached_counters), + cache: Cache::new(self.max_cached_counters as u64), } } } impl CountersCache { - pub fn get(&self, counter: &Counter) -> Option { - self.cache.get(counter).copied() + pub fn get(&self, counter: &Counter) -> Option> { + self.cache.get(counter) } pub fn insert( - &mut self, + &self, counter: Counter, redis_val: Option, redis_ttl_ms: i64, ttl_margin: Duration, - ) { + now: SystemTime, + ) -> Arc { let counter_val = redis_val.unwrap_or(0); - let counter_ttl = self.ttl_from_redis_ttl( + let cache_ttl = self.ttl_from_redis_ttl( redis_ttl_ms, counter.seconds(), counter_val, counter.max_value(), ); - if let Some(ttl) = counter_ttl.checked_sub(ttl_margin) { - if ttl > Duration::from_secs(0) { - self.cache.insert(counter, counter_val, ttl); + if let Some(ttl) = cache_ttl.checked_sub(ttl_margin) { + if ttl > Duration::ZERO { + let value = CachedCounterValue::from(&counter, counter_val, cache_ttl); + let previous = self.cache.get_with(counter.clone(), || Arc::new(value)); + if previous.expired_at(now) { + previous.set_from_authority(&counter, counter_val, cache_ttl); + } + return previous; } } + Arc::new(CachedCounterValue::from( + &counter, + counter_val, + Duration::ZERO, + )) } - pub fn increase_by(&mut self, counter: &Counter, delta: i64) { - if let Some(val) = self.cache.get_mut(counter) { - *val += delta + pub fn increase_by(&self, counter: &Counter, delta: i64) { + if let Some(val) = self.cache.get(counter) { + val.delta(counter, delta); }; } @@ -149,8 +208,14 @@ mod tests { values, ); - let mut cache = CountersCacheBuilder::new().build(); - cache.insert(counter.clone(), Some(10), 10, Duration::from_secs(0)); + let cache = CountersCacheBuilder::new().build(); + cache.insert( + counter.clone(), + Some(10), + 10, + Duration::from_secs(0), + SystemTime::now(), + ); assert!(cache.get(&counter).is_some()); } @@ -192,15 +257,19 @@ mod tests { values, ); - let mut cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(); cache.insert( counter.clone(), Some(current_value), 10, Duration::from_secs(0), + SystemTime::now(), ); - assert_eq!(cache.get(&counter).unwrap(), current_value); + assert_eq!( + cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), + current_value + ); } #[test] @@ -219,10 +288,16 @@ mod tests { values, ); - let mut cache = CountersCacheBuilder::new().build(); - cache.insert(counter.clone(), None, 10, Duration::from_secs(0)); + let cache = CountersCacheBuilder::new().build(); + cache.insert( + counter.clone(), + None, + 10, + Duration::from_secs(0), + SystemTime::now(), + ); - assert_eq!(cache.get(&counter).unwrap(), 0); + assert_eq!(cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), 0); } #[test] @@ -242,15 +317,19 @@ mod tests { values, ); - let mut cache = CountersCacheBuilder::new().build(); + let cache = CountersCacheBuilder::new().build(); cache.insert( counter.clone(), Some(current_val), 10, Duration::from_secs(0), + SystemTime::now(), ); cache.increase_by(&counter, increase_by); - assert_eq!(cache.get(&counter).unwrap(), current_val + increase_by); + assert_eq!( + cache.get(&counter).map(|e| e.hits(&counter)).unwrap(), + (current_val + increase_by) + ); } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index d5731ee0..c08989b5 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -1,5 +1,6 @@ use crate::counter::Counter; use crate::limit::Limit; +use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::keys::*; use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; use crate::storage::redis::redis_async::AsyncRedisStorage; @@ -15,8 +16,8 @@ use redis::{ConnectionInfo, RedisError}; use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex, MutexGuard}; -use std::time::{Duration, Instant}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant, SystemTime}; use tracing::{error, warn}; // This is just a first version. @@ -38,8 +39,8 @@ use tracing::{error, warn}; // multiple times when it is not cached. pub struct CachedRedisStorage { - cached_counters: Mutex, - batcher_counter_updates: Arc>>, + cached_counters: CountersCache, + batcher_counter_updates: Arc>>, async_redis_storage: AsyncRedisStorage, redis_conn_manager: ConnectionManager, partitioned: Arc, @@ -75,31 +76,27 @@ impl AsyncCounterStorage for CachedRedisStorage { let mut not_cached: Vec<&mut Counter> = vec![]; let mut first_limited = None; + let now = SystemTime::now(); // Check cached counters - { - let cached_counters = self.cached_counters.lock().unwrap(); - for counter in counters.iter_mut() { - match cached_counters.get(counter) { - Some(val) => { - if first_limited.is_none() && val + delta > counter.max_value() { - let a = Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - ); - if !load_counters { - return Ok(a); - } - first_limited = Some(a); - } - if load_counters { - counter.set_remaining(counter.max_value() - val - delta); - // todo: how do we get the ttl for this entry? - // counter.set_expires_in(Duration::from_secs(counter.seconds())); + for counter in counters.iter_mut() { + match self.cached_counters.get(counter) { + Some(val) if !val.expired_at(now) => { + if first_limited.is_none() && val.is_limited(counter, delta) { + let a = + Authorization::Limited(counter.limit().name().map(|n| n.to_owned())); + if !load_counters { + return Ok(a); } + first_limited = Some(a); } - None => { - not_cached.push(counter); + if load_counters { + counter.set_remaining(val.remaining(counter) - delta); + counter.set_expires_in(val.to_next_window()); } } + _ => { + not_cached.push(counter); + } } } @@ -127,31 +124,23 @@ impl AsyncCounterStorage for CachedRedisStorage { let ttl_margin = Duration::from_millis((Instant::now() - time_start_get_ttl).as_millis() as u64); - { - let mut cached_counters = self.cached_counters.lock().unwrap(); - for (i, counter) in not_cached.iter_mut().enumerate() { - cached_counters.insert( - counter.clone(), - counter_vals[i], - counter_ttls_msecs[i], - ttl_margin, - ); - let remaining = counter.max_value() - counter_vals[i].unwrap_or(0) - delta; - if first_limited.is_none() && remaining < 0 { - first_limited = Some(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } - if load_counters { - counter.set_remaining(remaining); - let counter_ttl = if counter_ttls_msecs[i] >= 0 { - Duration::from_millis(counter_ttls_msecs[i] as u64) - } else { - Duration::from_secs(counter.max_value() as u64) - }; - - counter.set_expires_in(counter_ttl); - } + for (i, counter) in not_cached.iter_mut().enumerate() { + let cached_value = self.cached_counters.insert( + counter.clone(), + counter_vals[i], + counter_ttls_msecs[i], + ttl_margin, + now, + ); + let remaining = cached_value.remaining(counter); + if first_limited.is_none() && remaining <= 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + if load_counters { + counter.set_remaining(remaining - delta); + counter.set_expires_in(cached_value.to_next_window()); } } } @@ -161,17 +150,28 @@ impl AsyncCounterStorage for CachedRedisStorage { } // Update cached values - { - let mut cached_counters = self.cached_counters.lock().unwrap(); - for counter in counters.iter() { - cached_counters.increase_by(counter, delta); - } + for counter in counters.iter() { + self.cached_counters.increase_by(counter, delta); } // Batch or update depending on configuration let mut batcher = self.batcher_counter_updates.lock().unwrap(); + let now = SystemTime::now(); for counter in counters.iter() { - Self::batch_counter(delta, &mut batcher, counter); + match batcher.get_mut(counter) { + Some(val) => { + val.update(delta, counter.seconds(), now); + } + None => { + batcher.insert( + counter.clone(), + AtomicExpiringValue::new( + delta, + now + Duration::from_secs(counter.seconds()), + ), + ); + } + } } Ok(Authorization::Ok) @@ -231,7 +231,8 @@ impl CachedRedisStorage { AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone()); let storage = async_redis_storage.clone(); - let batcher: Arc>> = Arc::new(Mutex::new(Default::default())); + let batcher: Arc>> = + Arc::new(Mutex::new(Default::default())); let p = Arc::clone(&partitioned); let batcher_flusher = batcher.clone(); let mut interval = tokio::time::interval(flushing_period); @@ -247,19 +248,16 @@ impl CachedRedisStorage { let mut batch = batcher_flusher.lock().unwrap(); std::mem::take(&mut *batch) }; + let now = SystemTime::now(); for (counter, delta) in counters { - storage - .update_counter(&counter, delta) - .await - .or_else(|err| { - if err.is_transient() { - p.store(true, Ordering::Release); - Ok(()) - } else { - Err(err) - } - }) - .expect("Unrecoverable Redis error!"); + let delta = delta.value_at(now); + if delta > 0 { + storage + .update_counter(&counter, delta) + .await + .or_else(|err| if err.is_transient() { Ok(()) } else { Err(err) }) + .expect("Unrecoverable Redis error!"); + } } } interval.tick().await; @@ -273,7 +271,7 @@ impl CachedRedisStorage { .build(); Ok(Self { - cached_counters: Mutex::new(cached_counters), + cached_counters, batcher_counter_updates: batcher, redis_conn_manager, async_redis_storage, @@ -336,21 +334,6 @@ impl CachedRedisStorage { Ok((counter_vals, counter_ttls_msecs)) } - - fn batch_counter( - delta: i64, - batcher: &mut MutexGuard>, - counter: &Counter, - ) { - match batcher.get_mut(counter) { - Some(val) => { - *val += delta; - } - None => { - batcher.insert(counter.clone(), delta); - } - } - } } pub struct CachedRedisStorageBuilder {