diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 05d04f04..a05ea043 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -42,10 +42,50 @@ impl CounterStorage for RocksDbStorage { #[tracing::instrument(skip_all)] fn check_and_update( + &self, + counters: &[Counter], + delta: u64, + ) -> Result { + let mut keys: Vec> = Vec::with_capacity(counters.len()); + + for counter in counters { + let key = key_for_counter(counter); + let slice: &[u8] = key.as_ref(); + let entry = { + let span = debug_span!("datastore"); + let _entered = span.enter(); + self.db.get(slice)? + }; + let (val, _) = match entry { + None => (0, Duration::from_secs(counter.limit().seconds())), + Some(raw) => { + let slice: &[u8] = raw.as_ref(); + let value: ExpiringValue = slice.try_into()?; + (value.value(), value.ttl()) + } + }; + + if counter.max_value() < val + delta { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_string()), + )); + } + + keys.push(key); + } + + for (idx, counter) in counters.iter().enumerate() { + self.insert_or_update(&keys[idx], counter, delta)?; + } + + Ok(Authorization::Ok) + } + + #[tracing::instrument(skip_all)] + fn check_and_update_loading( &self, counters: &mut Vec, delta: u64, - load_counters: bool, ) -> Result { let mut keys: Vec> = Vec::with_capacity(counters.len()); @@ -66,15 +106,13 @@ impl CounterStorage for RocksDbStorage { } }; - if load_counters { - counter.set_expires_in(ttl); - counter.set_remaining( - counter - .max_value() - .checked_sub(val + delta) - .unwrap_or_default(), - ); - } + counter.set_expires_in(ttl); + counter.set_remaining( + counter + .max_value() + .checked_sub(val + delta) + .unwrap_or_default(), + ); if counter.max_value() < val + delta { return Ok(Authorization::Limited( diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 81d9d6ed..fe725f0a 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -90,25 +90,14 @@ impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn check_and_update( &self, - counters: &mut Vec, + counters: &[Counter], delta: u64, - load_counters: bool, ) -> Result { - let mut first_limited = None; let mut counter_values_to_update: Vec> = Vec::new(); let now = SystemTime::now(); - let mut process_counter = - |counter: &mut Counter, value: u64, delta: u64| -> Option { - if load_counters { - let remaining = counter.max_value().checked_sub(value + delta); - counter.set_remaining(remaining.unwrap_or(0)); - if first_limited.is_none() && remaining.is_none() { - first_limited = Some(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } - } + let process_counter = + |counter: &Counter, value: u64, delta: u64| -> Option { if !Self::counter_is_within_limits(counter, Some(&value), delta) { return Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), @@ -118,7 +107,7 @@ impl CounterStorage for CrInMemoryStorage { }; // Process simple counters - for counter in counters.iter_mut() { + for counter in counters.iter() { let key = encode_counter_to_key(counter); // most of the time the counter should exist, so first try with a read only lock @@ -132,9 +121,7 @@ impl CounterStorage for CrInMemoryStorage { if let Some(limited) = process_counter(counter, store_value.value.read(), delta) { - if !load_counters { - return Ok(limited); - } + return Ok(limited); } counter_values_to_update.push(key); true @@ -157,10 +144,83 @@ impl CounterStorage for CrInMemoryStorage { })); if let Some(limited) = process_counter(counter, store_value.value.read(), delta) { - if !load_counters { - return Ok(limited); + return Ok(limited); + } + counter_values_to_update.push(key); + } + } + + // Update counters + let limits = self.limits.read().unwrap(); + counter_values_to_update.into_iter().for_each(|key| { + let store_value = limits.get(&key).unwrap(); + self.increment_counter(store_value.clone(), delta, now); + }); + + Ok(Authorization::Ok) + } + + #[tracing::instrument(skip_all)] + fn check_and_update_loading( + &self, + counters: &mut Vec, + delta: u64, + ) -> Result { + let mut first_limited = None; + let mut counter_values_to_update: Vec> = Vec::new(); + let now = SystemTime::now(); + + let mut process_counter = + |counter: &mut Counter, value: u64, delta: u64| -> Option { + let remaining = counter.max_value().checked_sub(value + delta); + counter.set_remaining(remaining.unwrap_or(0)); + if first_limited.is_none() && remaining.is_none() { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + if !Self::counter_is_within_limits(counter, Some(&value), delta) { + return Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + None + }; + + // Process simple counters + for counter in counters.iter_mut() { + let key = encode_counter_to_key(counter); + + // most of the time the counter should exist, so first try with a read only lock + // since that will allow us to have higher concurrency + let counter_existed = { + let key = key.clone(); + let limits = self.limits.read().unwrap(); + match limits.get(&key) { + None => false, + Some(store_value) => { + let _ = process_counter(counter, store_value.value.read(), delta); + counter_values_to_update.push(key); + true } } + }; + + // we need to take the slow path since we need to mutate the limits map. + if !counter_existed { + // try again with a write lock to create the counter if it's still missing. + let mut limits = self.limits.write().unwrap(); + let store_value = limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { + key: key.clone(), + counter: counter.clone(), + value: CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + counter.window(), + ), + })); + + let _ = process_counter(counter, store_value.value.read(), delta); counter_values_to_update.push(key); } } diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index ac5e2c07..dcf9480b 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -69,10 +69,71 @@ impl CounterStorage for InMemoryStorage { #[tracing::instrument(skip_all)] fn check_and_update( + &self, + counters: &[Counter], + delta: u64, + ) -> Result { + let limits_by_namespace = self.simple_limits.read().unwrap(); + let mut counter_values_to_update: Vec<(&AtomicExpiringValue, Duration)> = Vec::new(); + let mut qualified_counter_values_to_updated: Vec<(Arc, Duration)> = + Vec::new(); + let now = SystemTime::now(); + + let process_counter = + |counter: &Counter, value: u64, delta: u64| -> Option { + if !Self::counter_is_within_limits(counter, Some(&value), delta) { + return Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + None + }; + + // Process simple counters + for counter in counters.iter().filter(|c| !c.is_qualified()) { + let atomic_expiring_value: &AtomicExpiringValue = + limits_by_namespace.get(counter.limit()).unwrap(); + + if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) { + return Ok(limited); + } + counter_values_to_update.push((atomic_expiring_value, counter.window())); + } + + // Process qualified counters + for counter in counters.iter().filter(|c| c.is_qualified()) { + let value = match self.qualified_counters.get(counter) { + None => self.qualified_counters.get_with_by_ref(counter, || { + Arc::new(AtomicExpiringValue::new(0, now + counter.window())) + }), + Some(counter) => counter, + }; + + if let Some(limited) = process_counter(counter, value.value(), delta) { + return Ok(limited); + } + + qualified_counter_values_to_updated.push((value, counter.window())); + } + + // Update counters + counter_values_to_update.iter().for_each(|(v, ttl)| { + v.update(delta, *ttl, now); + }); + qualified_counter_values_to_updated + .iter() + .for_each(|(v, ttl)| { + v.update(delta, *ttl, now); + }); + + Ok(Authorization::Ok) + } + + #[tracing::instrument(skip_all)] + fn check_and_update_loading( &self, counters: &mut Vec, delta: u64, - load_counters: bool, ) -> Result { let limits_by_namespace = self.simple_limits.read().unwrap(); let mut first_limited = None; @@ -83,14 +144,12 @@ impl CounterStorage for InMemoryStorage { let mut process_counter = |counter: &mut Counter, value: u64, delta: u64| -> Option { - if load_counters { - let remaining = counter.max_value().checked_sub(value + delta); - counter.set_remaining(remaining.unwrap_or_default()); - if first_limited.is_none() && remaining.is_none() { - first_limited = Some(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } + let remaining = counter.max_value().checked_sub(value + delta); + counter.set_remaining(remaining.unwrap_or_default()); + if first_limited.is_none() && remaining.is_none() { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); } if !Self::counter_is_within_limits(counter, Some(&value), delta) { return Some(Authorization::Limited( @@ -105,11 +164,7 @@ impl CounterStorage for InMemoryStorage { let atomic_expiring_value: &AtomicExpiringValue = limits_by_namespace.get(counter.limit()).unwrap(); - if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) { - if !load_counters { - return Ok(limited); - } - } + let _ = process_counter(counter, atomic_expiring_value.value(), delta); counter_values_to_update.push((atomic_expiring_value, counter.window())); } @@ -122,12 +177,7 @@ impl CounterStorage for InMemoryStorage { Some(counter) => counter, }; - if let Some(limited) = process_counter(counter, value.value(), delta) { - if !load_counters { - return Ok(limited); - } - } - + let _ = process_counter(counter, value.value(), delta); qualified_counter_values_to_updated.push((value, counter.window())); } diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index 064cfb6c..dcf2d9f8 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -136,8 +136,11 @@ impl Storage { delta: u64, load_counters: bool, ) -> Result { - self.counters - .check_and_update(counters, delta, load_counters) + if load_counters { + self.counters.check_and_update_loading(counters, delta) + } else { + self.counters.check_and_update(counters, delta) + } } pub fn get_counters(&self, namespace: &Namespace) -> Result, StorageErr> { @@ -281,10 +284,14 @@ pub trait CounterStorage: Sync + Send { fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr>; fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr>; fn check_and_update( + &self, + counters: &[Counter], + delta: u64, + ) -> Result; + fn check_and_update_loading( &self, counters: &mut Vec, delta: u64, - load_counters: bool, ) -> Result; fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr>; // todo revise typing here? fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr>; // todo revise typing here? diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index 9e136096..b9833ebc 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -57,40 +57,60 @@ impl CounterStorage for RedisStorage { #[tracing::instrument(skip_all)] fn check_and_update( &self, - counters: &mut Vec, + counters: &[Counter], delta: u64, - load_counters: bool, ) -> Result { let mut con = self.conn_pool.get()?; let counter_keys: Vec> = counters.iter().map(key_for_counter).collect(); - if load_counters { - let script = redis::Script::new(VALUES_AND_TTLS); - let mut script_invocation = script.prepare_invoke(); - for counter_key in &counter_keys { - script_invocation.key(counter_key); + let counter_vals: Vec> = redis::cmd("MGET") + .arg(counter_keys.clone()) + .query(&mut *con)?; + + for (i, counter) in counters.iter().enumerate() { + // remaining = max - (curr_val + delta) + let remaining = counter + .max_value() + .checked_sub(u64::try_from(counter_vals[i].unwrap_or(0)).unwrap_or(0) + delta); + if remaining.is_none() { + return Ok(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); } - let script_res: Vec> = script_invocation.invoke(&mut *con)?; + } - if let Some(res) = is_limited(counters, delta, script_res) { - return Ok(res); - } - } else { - let counter_vals: Vec> = redis::cmd("MGET") - .arg(counter_keys.clone()) - .query(&mut *con)?; - - for (i, counter) in counters.iter().enumerate() { - // remaining = max - (curr_val + delta) - let remaining = counter - .max_value() - .checked_sub(u64::try_from(counter_vals[i].unwrap_or(0)).unwrap_or(0) + delta); - if remaining.is_none() { - return Ok(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )); - } - } + // TODO: this can be optimized by using pipelines with multiple updates + for (counter_idx, key) in counter_keys.into_iter().enumerate() { + let counter = &counters[counter_idx]; + redis::Script::new(SCRIPT_UPDATE_COUNTER) + .key(key) + .key(key_for_counters_of_limit(counter.limit())) + .arg(counter.window().as_secs()) + .arg(delta) + .invoke::<()>(&mut *con)?; + } + + Ok(Authorization::Ok) + } + + #[tracing::instrument(skip_all)] + fn check_and_update_loading( + &self, + counters: &mut Vec, + delta: u64, + ) -> Result { + let mut con = self.conn_pool.get()?; + let counter_keys: Vec> = counters.iter().map(key_for_counter).collect(); + + let script = redis::Script::new(VALUES_AND_TTLS); + let mut script_invocation = script.prepare_invoke(); + for counter_key in &counter_keys { + script_invocation.key(counter_key); + } + let script_res: Vec> = script_invocation.invoke(&mut *con)?; + + if let Some(res) = is_limited(counters, delta, script_res) { + return Ok(res); } // TODO: this can be optimized by using pipelines with multiple updates