Counter has Arc<Limit> to the limit, not a copy
alexsnaps committed Jun 4, 2024
1 parent eefd5d4 commit c3a9c09
Showing 8 changed files with 75 additions and 38 deletions.
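The heart of the change is in Counter: it now holds an Arc<Limit> instead of an owned Limit, and Counter::new becomes generic over Into<Arc<Limit>> so call sites can hand in either an owned Limit or an already-shared Arc. Below is a minimal, self-contained sketch of that pattern; the Limit struct here is a hypothetical stand-in used only to make the example compile, not limitador's real type.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical stand-in for limitador's `Limit`; only here so the sketch compiles.
#[derive(Debug)]
#[allow(dead_code)]
struct Limit {
    namespace: String,
    max_value: u64,
}

#[allow(dead_code)]
struct Counter {
    // The counter now shares the limit instead of owning its own copy.
    limit: Arc<Limit>,
    set_variables: HashMap<String, String>,
}

impl Counter {
    // `Into<Arc<Limit>>` accepts either an owned `Limit` (freshly wrapped in an Arc)
    // or an existing `Arc<Limit>` (reused as-is, no deep clone).
    fn new<L: Into<Arc<Limit>>>(limit: L, set_variables: HashMap<String, String>) -> Self {
        Self {
            limit: limit.into(),
            set_variables,
        }
    }
}

fn main() {
    let shared = Arc::new(Limit {
        namespace: "test".into(),
        max_value: 10,
    });

    // Both call sites compile; the second only bumps the Arc's refcount.
    let _from_owned = Counter::new(
        Limit {
            namespace: "test".into(),
            max_value: 10,
        },
        HashMap::new(),
    );
    let _from_shared = Counter::new(Arc::clone(&shared), HashMap::new());
}
```

Passing an existing Arc only increments a reference count, which is what the updated call sites in lib.rs and the storage backends below rely on.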
17 changes: 8 additions & 9 deletions limitador/src/counter.rs
@@ -2,11 +2,12 @@ use crate::limit::{Limit, Namespace};
use serde::{Deserialize, Serialize, Serializer};
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

#[derive(Eq, Clone, Debug, Serialize, Deserialize)]
pub struct Counter {
limit: Limit,
limit: Arc<Limit>,

// Need to sort to generate the same object when using the JSON as a key or
// value in Redis.
@@ -26,9 +27,10 @@ where
}

impl Counter {
pub fn new(limit: Limit, set_variables: HashMap<String, String>) -> Self {
pub fn new<L: Into<Arc<Limit>>>(limit: L, set_variables: HashMap<String, String>) -> Self {
// TODO: check that all the variables defined in the limit are set.

let limit = limit.into();
let mut vars = set_variables;
vars.retain(|var, _| limit.has_variable(var));

@@ -43,7 +45,7 @@ impl Counter {
#[cfg(any(feature = "redis_storage", feature = "disk_storage"))]
pub(crate) fn key(&self) -> Self {
Self {
limit: self.limit.clone(),
limit: Arc::clone(&self.limit),
set_variables: self.set_variables.clone(),
remaining: None,
expires_in: None,
@@ -58,12 +60,9 @@ impl Counter {
self.limit.max_value()
}

pub fn update_to_limit(&mut self, limit: &Limit) -> bool {
if limit == &self.limit {
self.limit.set_max_value(limit.max_value());
if let Some(name) = limit.name() {
self.limit.set_name(name.to_string());
}
pub fn update_to_limit(&mut self, limit: Arc<Limit>) -> bool {
if limit == self.limit {
self.limit = Arc::clone(&limit);
return true;
}
false
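A subtle consequence shows up in update_to_limit above: the old code copied max_value and name from the incoming limit after an equality check, which suggests Limit's PartialEq deliberately ignores those fields. With a shared Arc<Limit>, the counter instead re-points to the incoming Arc and picks up the new values without mutating anything. A rough sketch of that behaviour, under the assumption (not shown in this diff) that equality excludes max_value:

```rust
use std::sync::Arc;

// Hypothetical `Limit` whose equality ignores `max_value`, mirroring what the
// removed copy-the-fields code implies about limitador's real type.
#[derive(Debug)]
struct Limit {
    id: String,
    max_value: u64,
}

impl PartialEq for Limit {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id // max_value is not part of a limit's identity here
    }
}

struct Counter {
    limit: Arc<Limit>,
}

impl Counter {
    // With a shared limit, "updating" is just re-pointing the Arc: if the incoming
    // limit is the same logical limit, adopt it (and whatever new max_value it carries).
    fn update_to_limit(&mut self, limit: Arc<Limit>) -> bool {
        if *limit == *self.limit {
            self.limit = Arc::clone(&limit);
            return true;
        }
        false
    }
}

fn main() {
    let mut counter = Counter {
        limit: Arc::new(Limit {
            id: "per-user".into(),
            max_value: 10,
        }),
    };
    let raised = Arc::new(Limit {
        id: "per-user".into(),
        max_value: 20,
    });

    assert!(counter.update_to_limit(Arc::clone(&raised)));
    assert_eq!(counter.limit.max_value, 20); // the counter now sees the raised limit
}
```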
21 changes: 15 additions & 6 deletions limitador/src/lib.rs
@@ -193,6 +193,7 @@
#![allow(clippy::multiple_crate_versions)]

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::counter::Counter;
use crate::errors::LimitadorError;
@@ -341,7 +342,11 @@ impl RateLimiter {
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
self.storage
.get_limits(namespace)
.iter()
.map(|l| (**l).clone())
.collect()
}

pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
@@ -475,12 +480,12 @@ impl RateLimiter {
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let limits = self.storage.get_limits(namespace);

let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.map(|lim| Counter::new(Arc::clone(lim), values.clone()))
.collect();

Ok(counters)
@@ -513,7 +518,11 @@ impl AsyncRateLimiter {
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
self.storage
.get_limits(namespace)
.iter()
.map(|l| (**l).clone())
.collect()
}

pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
@@ -653,12 +662,12 @@ impl AsyncRateLimiter {
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let limits = self.storage.get_limits(namespace);

let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.map(|lim| Counter::new(Arc::clone(lim), values.clone()))
.collect();

Ok(counters)
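Note the split in lib.rs above: Storage::get_limits now hands back HashSet<Arc<Limit>>, so the internal counter-building paths clone the Arc, while the public RateLimiter::get_limits keeps returning HashSet<Limit> by cloning the value out of each Arc with (**l).clone(). The double deref is what separates a cheap handle clone from a copy of the underlying data; a tiny sketch with String standing in for Limit:

```rust
use std::collections::HashSet;
use std::sync::Arc;

fn main() {
    let shared: HashSet<Arc<String>> = ["a", "b"]
        .into_iter()
        .map(|s| Arc::new(s.to_string()))
        .collect();

    // In the closure, `l` is `&Arc<String>`: one `*` yields `Arc<String>`, a second
    // yields `String`. `(**l).clone()` therefore copies the value out of the Arc,
    // whereas `l.clone()` would merely clone the handle (bump the refcount).
    let owned: HashSet<String> = shared.iter().map(|l| (**l).clone()).collect();

    assert_eq!(owned.len(), 2);
}
```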
2 changes: 1 addition & 1 deletion limitador/src/storage/disk/rocksdb_storage.rs
@@ -116,7 +116,7 @@ impl CounterStorage for RocksDbStorage {
let value: ExpiringValue = value.as_ref().try_into()?;
for limit in limits {
if limit.deref() == counter.limit() {
counter.update_to_limit(limit);
counter.update_to_limit(Arc::clone(limit));
let ttl = value.ttl();
counter.set_expires_in(ttl);
counter.set_remaining(limit.max_value() - value.value());
5 changes: 3 additions & 2 deletions limitador/src/storage/keys.rs
@@ -14,6 +14,7 @@

use crate::counter::Counter;
use crate::limit::Limit;
use std::sync::Arc;

pub fn key_for_counter(counter: &Counter) -> String {
if counter.remaining().is_some() || counter.expires_in().is_some() {
@@ -43,9 +44,9 @@ pub fn prefix_for_namespace(namespace: &str) -> String {
format!("namespace:{{{namespace}}},")
}

pub fn counter_from_counter_key(key: &str, limit: &Limit) -> Counter {
pub fn counter_from_counter_key(key: &str, limit: Arc<Limit>) -> Counter {
let mut counter = partial_counter_from_counter_key(key);
if !counter.update_to_limit(limit) {
if !counter.update_to_limit(Arc::clone(&limit)) {
// this means some kind of data corruption _or_ most probably
// an out of sync `impl PartialEq for Limit` vs `pub fn key_for_counter(counter: &Counter) -> String`
panic!(
42 changes: 30 additions & 12 deletions limitador/src/storage/mod.rs
@@ -81,17 +81,25 @@ impl Storage {
false
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Arc<Limit>> {
match self.limits.read().unwrap().get(namespace) {
// todo revise typing here?
Some(limits) => limits.iter().map(|l| (**l).clone()).collect(),
Some(limits) => limits.iter().map(Arc::clone).collect(),
None => HashSet::new(),
}
}

pub fn delete_limit(&self, limit: &Limit) -> Result<(), StorageErr> {
let arc = match self.limits.read().unwrap().get(limit.namespace()) {
None => Arc::new(limit.clone()),
Some(limits) => limits
.iter()
.find(|l| ***l == *limit)
.cloned()
.unwrap_or_else(|| Arc::new(limit.clone())),
};
let mut limits = HashSet::new();
limits.insert(Arc::new(limit.clone()));
limits.insert(arc);
self.counters.delete_counters(&limits)?;

let mut limits = self.limits.write().unwrap();
@@ -190,17 +198,25 @@ impl AsyncStorage {
false
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Arc<Limit>> {
match self.limits.read().unwrap().get(namespace) {
Some(limits) => limits.iter().map(|l| (**l).clone()).collect(),
Some(limits) => limits.iter().map(Arc::clone).collect(),
None => HashSet::new(),
}
}

pub async fn delete_limit(&self, limit: &Limit) -> Result<(), StorageErr> {
let arc = match self.limits.read().unwrap().get(limit.namespace()) {
None => Arc::new(limit.clone()),
Some(limits) => limits
.iter()
.find(|l| ***l == *limit)
.cloned()
.unwrap_or_else(|| Arc::new(limit.clone())),
};
let mut limits = HashSet::new();
limits.insert(limit.clone());
self.counters.delete_counters(limits).await?;
limits.insert(arc);
self.counters.delete_counters(&limits).await?;

let mut limits_for_namespace = self.limits.write().unwrap();

@@ -217,8 +233,7 @@
pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), StorageErr> {
let option = { self.limits.write().unwrap().remove(namespace) };
if let Some(data) = option {
let limits = data.iter().map(|l| (**l).clone()).collect();
self.counters.delete_counters(limits).await?;
self.counters.delete_counters(&data).await?;
}
Ok(())
}
@@ -251,7 +266,7 @@
namespace: &Namespace,
) -> Result<HashSet<Counter>, StorageErr> {
let limits = self.get_limits(namespace);
self.counters.get_counters(limits).await
self.counters.get_counters(&limits).await
}

pub async fn clear(&self) -> Result<(), StorageErr> {
@@ -285,8 +300,11 @@ pub trait AsyncCounterStorage: Sync + Send {
delta: u64,
load_counters: bool,
) -> Result<Authorization, StorageErr>;
async fn get_counters(&self, limits: HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr>;
async fn delete_counters(&self, limits: HashSet<Limit>) -> Result<(), StorageErr>;
async fn get_counters(
&self,
limits: &HashSet<Arc<Limit>>,
) -> Result<HashSet<Counter>, StorageErr>;
async fn delete_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<(), StorageErr>;
async fn clear(&self) -> Result<(), StorageErr>;
}

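Both delete_limit implementations above now look the stored Arc<Limit> up in the in-memory map before touching counter storage, falling back to a fresh Arc only when the limit isn't registered. Inside find the closure parameter is a double reference, hence the ***l needed to reach the Limit itself. The same lookup-or-wrap pattern, again with String as a stand-in:

```rust
use std::collections::HashSet;
use std::sync::Arc;

fn main() {
    let stored: HashSet<Arc<String>> = ["a", "b"]
        .into_iter()
        .map(|s| Arc::new(s.to_string()))
        .collect();
    let wanted = String::from("a");

    // In `find`, `l` is `&&Arc<String>`, so three derefs reach the `String` itself.
    // Reuse the Arc already held in the set; only allocate a fresh one as a fallback.
    let arc = stored
        .iter()
        .find(|l| ***l == wanted)
        .cloned()
        .unwrap_or_else(|| Arc::new(wanted.clone()));

    assert_eq!(*arc, "a");
}
```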
16 changes: 11 additions & 5 deletions limitador/src/storage/redis/redis_async.rs
@@ -11,7 +11,9 @@ use crate::storage::{AsyncCounterStorage, Authorization, StorageErr};
use async_trait::async_trait;
use redis::{AsyncCommands, RedisError};
use std::collections::HashSet;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug_span, Instrument};

@@ -127,20 +129,24 @@ impl AsyncCounterStorage for AsyncRedisStorage {
}

#[tracing::instrument(skip_all)]
async fn get_counters(&self, limits: HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr> {
async fn get_counters(
&self,
limits: &HashSet<Arc<Limit>>,
) -> Result<HashSet<Counter>, StorageErr> {
let mut res = HashSet::new();

let mut con = self.conn_manager.clone();

for limit in limits {
let counter_keys = {
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(&limit))
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(limit))
.instrument(debug_span!("datastore"))
.await?
};

for counter_key in counter_keys {
let mut counter: Counter = counter_from_counter_key(&counter_key, &limit);
let mut counter: Counter =
counter_from_counter_key(&counter_key, Arc::clone(limit));

// If the key does not exist, it means that the counter expired,
// so we don't have to return it.
@@ -172,9 +178,9 @@
}

#[tracing::instrument(skip_all)]
async fn delete_counters(&self, limits: HashSet<Limit>) -> Result<(), StorageErr> {
async fn delete_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<(), StorageErr> {
for limit in limits {
self.delete_counters_associated_with_limit(&limit)
self.delete_counters_associated_with_limit(limit.deref())
.instrument(debug_span!("datastore"))
.await?
}
7 changes: 5 additions & 2 deletions limitador/src/storage/redis/redis_cached.rs
@@ -132,12 +132,15 @@ impl AsyncCounterStorage for CachedRedisStorage {
}

#[tracing::instrument(skip_all)]
async fn get_counters(&self, limits: HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr> {
async fn get_counters(
&self,
limits: &HashSet<Arc<Limit>>,
) -> Result<HashSet<Counter>, StorageErr> {
self.async_redis_storage.get_counters(limits).await
}

#[tracing::instrument(skip_all)]
async fn delete_counters(&self, limits: HashSet<Limit>) -> Result<(), StorageErr> {
async fn delete_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<(), StorageErr> {
self.async_redis_storage.delete_counters(limits).await
}

3 changes: 2 additions & 1 deletion limitador/src/storage/redis/redis_sync.rs
@@ -118,7 +118,8 @@ impl CounterStorage for RedisStorage {
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(limit))?;

for counter_key in counter_keys {
let mut counter: Counter = counter_from_counter_key(&counter_key, limit);
let mut counter: Counter =
counter_from_counter_key(&counter_key, Arc::clone(limit));

// If the key does not exist, it means that the counter expired,
// so we don't have to return it.