From c3a9c0929ae73403872c6489f16c95387a8bf6c2 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Tue, 4 Jun 2024 15:01:52 -0400 Subject: [PATCH] Counter has `Arc` to the limit, not a copy --- limitador/src/counter.rs | 17 ++++---- limitador/src/lib.rs | 21 +++++++--- limitador/src/storage/disk/rocksdb_storage.rs | 2 +- limitador/src/storage/keys.rs | 5 ++- limitador/src/storage/mod.rs | 42 +++++++++++++------ limitador/src/storage/redis/redis_async.rs | 16 ++++--- limitador/src/storage/redis/redis_cached.rs | 7 +++- limitador/src/storage/redis/redis_sync.rs | 3 +- 8 files changed, 75 insertions(+), 38 deletions(-) diff --git a/limitador/src/counter.rs b/limitador/src/counter.rs index 9763d627..ed6cd999 100644 --- a/limitador/src/counter.rs +++ b/limitador/src/counter.rs @@ -2,11 +2,12 @@ use crate::limit::{Limit, Namespace}; use serde::{Deserialize, Serialize, Serializer}; use std::collections::{BTreeMap, HashMap}; use std::hash::{Hash, Hasher}; +use std::sync::Arc; use std::time::Duration; #[derive(Eq, Clone, Debug, Serialize, Deserialize)] pub struct Counter { - limit: Limit, + limit: Arc, // Need to sort to generate the same object when using the JSON as a key or // value in Redis. @@ -26,9 +27,10 @@ where } impl Counter { - pub fn new(limit: Limit, set_variables: HashMap) -> Self { + pub fn new>>(limit: L, set_variables: HashMap) -> Self { // TODO: check that all the variables defined in the limit are set. + let limit = limit.into(); let mut vars = set_variables; vars.retain(|var, _| limit.has_variable(var)); @@ -43,7 +45,7 @@ impl Counter { #[cfg(any(feature = "redis_storage", feature = "disk_storage"))] pub(crate) fn key(&self) -> Self { Self { - limit: self.limit.clone(), + limit: Arc::clone(&self.limit), set_variables: self.set_variables.clone(), remaining: None, expires_in: None, @@ -58,12 +60,9 @@ impl Counter { self.limit.max_value() } - pub fn update_to_limit(&mut self, limit: &Limit) -> bool { - if limit == &self.limit { - self.limit.set_max_value(limit.max_value()); - if let Some(name) = limit.name() { - self.limit.set_name(name.to_string()); - } + pub fn update_to_limit(&mut self, limit: Arc) -> bool { + if limit == self.limit { + self.limit = Arc::clone(&limit); return true; } false diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index fdc4dc5f..a71de204 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -193,6 +193,7 @@ #![allow(clippy::multiple_crate_versions)] use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use crate::counter::Counter; use crate::errors::LimitadorError; @@ -341,7 +342,11 @@ impl RateLimiter { } pub fn get_limits(&self, namespace: &Namespace) -> HashSet { - self.storage.get_limits(namespace) + self.storage + .get_limits(namespace) + .iter() + .map(|l| (**l).clone()) + .collect() } pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> { @@ -475,12 +480,12 @@ impl RateLimiter { namespace: &Namespace, values: &HashMap, ) -> Result, LimitadorError> { - let limits = self.get_limits(namespace); + let limits = self.storage.get_limits(namespace); let counters = limits .iter() .filter(|lim| lim.applies(values)) - .map(|lim| Counter::new(lim.clone(), values.clone())) + .map(|lim| Counter::new(Arc::clone(lim), values.clone())) .collect(); Ok(counters) @@ -513,7 +518,11 @@ impl AsyncRateLimiter { } pub fn get_limits(&self, namespace: &Namespace) -> HashSet { - self.storage.get_limits(namespace) + self.storage + .get_limits(namespace) + .iter() + .map(|l| (**l).clone()) + .collect() } pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> { @@ -653,12 +662,12 @@ impl AsyncRateLimiter { namespace: &Namespace, values: &HashMap, ) -> Result, LimitadorError> { - let limits = self.get_limits(namespace); + let limits = self.storage.get_limits(namespace); let counters = limits .iter() .filter(|lim| lim.applies(values)) - .map(|lim| Counter::new(lim.clone(), values.clone())) + .map(|lim| Counter::new(Arc::clone(lim), values.clone())) .collect(); Ok(counters) diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 83a64771..148af984 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -116,7 +116,7 @@ impl CounterStorage for RocksDbStorage { let value: ExpiringValue = value.as_ref().try_into()?; for limit in limits { if limit.deref() == counter.limit() { - counter.update_to_limit(limit); + counter.update_to_limit(Arc::clone(limit)); let ttl = value.ttl(); counter.set_expires_in(ttl); counter.set_remaining(limit.max_value() - value.value()); diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 6d32977c..81d818c6 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -14,6 +14,7 @@ use crate::counter::Counter; use crate::limit::Limit; +use std::sync::Arc; pub fn key_for_counter(counter: &Counter) -> String { if counter.remaining().is_some() || counter.expires_in().is_some() { @@ -43,9 +44,9 @@ pub fn prefix_for_namespace(namespace: &str) -> String { format!("namespace:{{{namespace}}},") } -pub fn counter_from_counter_key(key: &str, limit: &Limit) -> Counter { +pub fn counter_from_counter_key(key: &str, limit: Arc) -> Counter { let mut counter = partial_counter_from_counter_key(key); - if !counter.update_to_limit(limit) { + if !counter.update_to_limit(Arc::clone(&limit)) { // this means some kind of data corruption _or_ most probably // an out of sync `impl PartialEq for Limit` vs `pub fn key_for_counter(counter: &Counter) -> String` panic!( diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index da2fcdcf..403d21a6 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -81,17 +81,25 @@ impl Storage { false } - pub fn get_limits(&self, namespace: &Namespace) -> HashSet { + pub fn get_limits(&self, namespace: &Namespace) -> HashSet> { match self.limits.read().unwrap().get(namespace) { // todo revise typing here? - Some(limits) => limits.iter().map(|l| (**l).clone()).collect(), + Some(limits) => limits.iter().map(Arc::clone).collect(), None => HashSet::new(), } } pub fn delete_limit(&self, limit: &Limit) -> Result<(), StorageErr> { + let arc = match self.limits.read().unwrap().get(limit.namespace()) { + None => Arc::new(limit.clone()), + Some(limits) => limits + .iter() + .find(|l| ***l == *limit) + .cloned() + .unwrap_or_else(|| Arc::new(limit.clone())), + }; let mut limits = HashSet::new(); - limits.insert(Arc::new(limit.clone())); + limits.insert(arc); self.counters.delete_counters(&limits)?; let mut limits = self.limits.write().unwrap(); @@ -190,17 +198,25 @@ impl AsyncStorage { false } - pub fn get_limits(&self, namespace: &Namespace) -> HashSet { + pub fn get_limits(&self, namespace: &Namespace) -> HashSet> { match self.limits.read().unwrap().get(namespace) { - Some(limits) => limits.iter().map(|l| (**l).clone()).collect(), + Some(limits) => limits.iter().map(Arc::clone).collect(), None => HashSet::new(), } } pub async fn delete_limit(&self, limit: &Limit) -> Result<(), StorageErr> { + let arc = match self.limits.read().unwrap().get(limit.namespace()) { + None => Arc::new(limit.clone()), + Some(limits) => limits + .iter() + .find(|l| ***l == *limit) + .cloned() + .unwrap_or_else(|| Arc::new(limit.clone())), + }; let mut limits = HashSet::new(); - limits.insert(limit.clone()); - self.counters.delete_counters(limits).await?; + limits.insert(arc); + self.counters.delete_counters(&limits).await?; let mut limits_for_namespace = self.limits.write().unwrap(); @@ -217,8 +233,7 @@ impl AsyncStorage { pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), StorageErr> { let option = { self.limits.write().unwrap().remove(namespace) }; if let Some(data) = option { - let limits = data.iter().map(|l| (**l).clone()).collect(); - self.counters.delete_counters(limits).await?; + self.counters.delete_counters(&data).await?; } Ok(()) } @@ -251,7 +266,7 @@ impl AsyncStorage { namespace: &Namespace, ) -> Result, StorageErr> { let limits = self.get_limits(namespace); - self.counters.get_counters(limits).await + self.counters.get_counters(&limits).await } pub async fn clear(&self) -> Result<(), StorageErr> { @@ -285,8 +300,11 @@ pub trait AsyncCounterStorage: Sync + Send { delta: u64, load_counters: bool, ) -> Result; - async fn get_counters(&self, limits: HashSet) -> Result, StorageErr>; - async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr>; + async fn get_counters( + &self, + limits: &HashSet>, + ) -> Result, StorageErr>; + async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr>; async fn clear(&self) -> Result<(), StorageErr>; } diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index 18175c75..d29e7b3a 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -11,7 +11,9 @@ use crate::storage::{AsyncCounterStorage, Authorization, StorageErr}; use async_trait::async_trait; use redis::{AsyncCommands, RedisError}; use std::collections::HashSet; +use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use tracing::{debug_span, Instrument}; @@ -127,20 +129,24 @@ impl AsyncCounterStorage for AsyncRedisStorage { } #[tracing::instrument(skip_all)] - async fn get_counters(&self, limits: HashSet) -> Result, StorageErr> { + async fn get_counters( + &self, + limits: &HashSet>, + ) -> Result, StorageErr> { let mut res = HashSet::new(); let mut con = self.conn_manager.clone(); for limit in limits { let counter_keys = { - con.smembers::>(key_for_counters_of_limit(&limit)) + con.smembers::>(key_for_counters_of_limit(limit)) .instrument(debug_span!("datastore")) .await? }; for counter_key in counter_keys { - let mut counter: Counter = counter_from_counter_key(&counter_key, &limit); + let mut counter: Counter = + counter_from_counter_key(&counter_key, Arc::clone(limit)); // If the key does not exist, it means that the counter expired, // so we don't have to return it. @@ -172,9 +178,9 @@ impl AsyncCounterStorage for AsyncRedisStorage { } #[tracing::instrument(skip_all)] - async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { for limit in limits { - self.delete_counters_associated_with_limit(&limit) + self.delete_counters_associated_with_limit(limit.deref()) .instrument(debug_span!("datastore")) .await? } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index fbf4aa89..9a3ae681 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -132,12 +132,15 @@ impl AsyncCounterStorage for CachedRedisStorage { } #[tracing::instrument(skip_all)] - async fn get_counters(&self, limits: HashSet) -> Result, StorageErr> { + async fn get_counters( + &self, + limits: &HashSet>, + ) -> Result, StorageErr> { self.async_redis_storage.get_counters(limits).await } #[tracing::instrument(skip_all)] - async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { self.async_redis_storage.delete_counters(limits).await } diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index 82141236..81eb3f11 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -118,7 +118,8 @@ impl CounterStorage for RedisStorage { con.smembers::>(key_for_counters_of_limit(limit))?; for counter_key in counter_keys { - let mut counter: Counter = counter_from_counter_key(&counter_key, limit); + let mut counter: Counter = + counter_from_counter_key(&counter_key, Arc::clone(limit)); // If the key does not exist, it means that the counter expired, // so we don't have to return it.