diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index 0353e041..f8d19ee0 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -28,8 +28,8 @@ impl AtomicExpiringValue { } #[allow(dead_code)] - pub fn add_and_set_expiry(&self, delta: u64, expire_at: SystemTime) -> u64 { - self.expiry.update(expire_at); + pub fn add_and_set_expiry(&self, delta: u64, expiry: SystemTime) -> u64 { + self.expiry.update(expiry); self.value.fetch_add(delta, Ordering::SeqCst) + delta } diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 2d6e31f9..ef28413b 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -297,32 +297,25 @@ impl CountersCache { counter: Counter, redis_val: u64, remote_deltas: u64, - redis_expiry: i64, + expiry: SystemTime, ) -> Arc { - if redis_expiry > 0 { - let expiry_ts = SystemTime::UNIX_EPOCH + Duration::from_millis(redis_expiry as u64); - if expiry_ts > SystemTime::now() { - let mut from_cache = true; - let cached = self.cache.get_with(counter.clone(), || { + if expiry > SystemTime::now() { + let mut from_cache = true; + let cached = self.cache.get_with(counter.clone(), || { + from_cache = false; + if let Some(entry) = self.batcher.updates.get(&counter) { gauge!("cache_size").increment(1); - from_cache = false; - if let Some(entry) = self.batcher.updates.get(&counter) { - let cached_value = entry.value(); - cached_value.add_from_authority( - remote_deltas, - expiry_ts, - counter.max_value(), - ); - cached_value.clone() - } else { - Arc::new(CachedCounterValue::from_authority(&counter, redis_val)) - } - }); - if from_cache { - cached.add_from_authority(remote_deltas, expiry_ts, counter.max_value()); + let cached_value = entry.value(); + cached_value.add_from_authority(remote_deltas, expiry, counter.max_value()); + cached_value.clone() + } else { + Arc::new(CachedCounterValue::from_authority(&counter, redis_val)) } - return cached; + }); + if from_cache { + cached.add_from_authority(remote_deltas, expiry, counter.max_value()); } + return cached; } Arc::new(CachedCounterValue::load_from_authority_asap( &counter, redis_val, @@ -385,7 +378,6 @@ impl CountersCacheBuilder { mod tests { use std::collections::HashMap; use std::ops::Add; - use std::time::UNIX_EPOCH; use crate::limit::Limit; @@ -613,11 +605,7 @@ mod tests { counter.clone(), 10, 0, - SystemTime::now() - .add(Duration::from_secs(1)) - .duration_since(UNIX_EPOCH) - .unwrap() - .as_micros() as i64, + SystemTime::now().add(Duration::from_secs(1)), ); assert!(cache.get(&counter).is_some()); @@ -643,11 +631,7 @@ mod tests { counter.clone(), current_value, 0, - SystemTime::now() - .add(Duration::from_secs(1)) - .duration_since(UNIX_EPOCH) - .unwrap() - .as_micros() as i64, + SystemTime::now().add(Duration::from_secs(1)), ); assert_eq!( @@ -667,11 +651,7 @@ mod tests { counter.clone(), current_val, 0, - SystemTime::now() - .add(Duration::from_secs(1)) - .duration_since(UNIX_EPOCH) - .unwrap() - .as_micros() as i64, + SystemTime::now().add(Duration::from_secs(1)), ); cache.increase_by(&counter, increase_by).await; diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 3e1125bf..0b9689e2 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -19,7 +19,7 @@ use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tracing::{debug_span, error, warn, Instrument}; // This is just a first version. @@ -284,14 +284,15 @@ impl CachedRedisStorageBuilder { async fn update_counters( redis_conn: &mut C, counters_and_deltas: HashMap>, -) -> Result, (Vec<(Counter, u64, u64, i64)>, StorageErr)> { +) -> Result, (Vec<(Counter, u64, u64, SystemTime)>, StorageErr)> { let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); let res = if counters_and_deltas.is_empty() { Default::default() } else { - let mut res: Vec<(Counter, u64, u64, i64)> = Vec::with_capacity(counters_and_deltas.len()); + let mut res: Vec<(Counter, u64, u64, SystemTime)> = + Vec::with_capacity(counters_and_deltas.len()); for (counter, value) in counters_and_deltas { let (delta, last_value_from_redis) = value @@ -303,7 +304,7 @@ async fn update_counters( script_invocation.arg(counter.window().as_secs()); script_invocation.arg(delta); // We need to store the counter in the actual order we are sending it to the script - res.push((counter, last_value_from_redis, delta, 0)); + res.push((counter, 0, last_value_from_redis, UNIX_EPOCH)); } } @@ -329,7 +330,8 @@ async fn update_counters( .unwrap_or(0) .saturating_sub(*val); // new value - previous one = remote writes *val = u64::try_from(script_res[j]).unwrap_or(0); // update to value to newest - *expires_at = script_res[j + 1]; + *expires_at = + UNIX_EPOCH + Duration::from_millis(u64::try_from(script_res[j + 1]).unwrap_or(0)); } res }; @@ -444,13 +446,15 @@ mod tests { arc.delta(&counter, LOCAL_INCREMENTS); counters_and_deltas.insert(counter.clone(), arc); - let one_sec_from_now = SystemTime::now() - .add(Duration::from_secs(1)) - .duration_since(UNIX_EPOCH) - .unwrap(); + let one_sec_from_now = SystemTime::now().add(Duration::from_secs(1)); let mock_response = Value::Bulk(vec![ Value::Int(NEW_VALUE_FROM_REDIS as i64), - Value::Int(one_sec_from_now.as_millis() as i64), + Value::Int( + one_sec_from_now + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64, + ), ]); let mut mock_client = MockRedisConnection::new(vec![MockCmd::new( @@ -475,7 +479,13 @@ mod tests { NEW_VALUE_FROM_REDIS - INITIAL_VALUE_FROM_REDIS - LOCAL_INCREMENTS, remote_increments ); - assert_eq!(one_sec_from_now.as_millis(), expire_at as u128); + assert_eq!( + one_sec_from_now + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(), + expire_at.duration_since(UNIX_EPOCH).unwrap().as_millis() + ); } #[tokio::test]