diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 3606730f..fd125d38 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -694,3 +694,48 @@ fn classify_limits_by_namespace( res } + +#[cfg(test)] +mod test { + use crate::limit::Limit; + use crate::RateLimiter; + use std::collections::HashMap; + + #[test] + fn properly_updates_existing_limits() { + let rl = RateLimiter::new(100); + let namespace = "foo"; + + let l = Limit::new::<_, String>( + namespace, + 42, + 100, + Vec::::default(), + Vec::::default(), + ); + rl.add_limit(l.clone()); + let limits = rl.get_limits(&namespace.into()); + assert_eq!(limits.len(), 1); + assert!(limits.contains(&l)); + assert_eq!(limits.iter().next().unwrap().max_value(), 42); + + let r = rl + .check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true) + .unwrap(); + assert_eq!(r.counters.first().unwrap().max_value(), 42); + + let mut l = l.clone(); + l.set_max_value(50); + + rl.configure_with([l.clone()]).unwrap(); + let limits = rl.get_limits(&namespace.into()); + assert_eq!(limits.len(), 1); + assert!(limits.contains(&l)); + assert_eq!(limits.iter().next().unwrap().max_value(), 50); + + let r = rl + .check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true) + .unwrap(); + assert_eq!(r.counters.first().unwrap().max_value(), 50); + } +} diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index dc77f159..19dcaae7 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -49,7 +49,7 @@ impl From for Namespace { } } -#[derive(Eq, Debug, Clone, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Eq, Debug, Clone, Serialize, Deserialize)] pub struct Limit { #[serde(skip_serializing, default)] id: Option, @@ -426,6 +426,27 @@ impl Hash for Limit { } } +impl PartialOrd for Limit { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Limit { + fn cmp(&self, other: &Self) -> Ordering { + match self.namespace.cmp(&other.namespace) { + Ordering::Equal => match self.seconds.cmp(&other.seconds) { + Ordering::Equal => match self.conditions.cmp(&other.conditions) { + Ordering::Equal => self.variables.cmp(&other.variables), + cmp => cmp, + }, + cmp => cmp, + }, + cmp => cmp, + } + } +} + impl PartialEq for Limit { fn eq(&self, other: &Self) -> bool { self.namespace == other.namespace @@ -833,6 +854,7 @@ mod conditions { #[cfg(test)] mod tests { use super::*; + use std::cmp::Ordering::Equal; #[test] fn limit_can_have_an_optional_name() { @@ -1027,4 +1049,28 @@ mod tests { assert_eq!(limit.id(), Some("test_id")) } + + #[test] + fn partial_equality() { + let limit1 = Limit::with_id( + "test_id", + "test_namespace", + 42, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); + + let mut limit2 = Limit::new( + limit1.namespace.clone(), + limit1.max_value + 10, + limit1.seconds, + limit1.conditions.clone(), + limit1.variables.clone(), + ); + limit2.set_name("Who cares?".to_string()); + + assert_eq!(limit1.partial_cmp(&limit2), Some(Equal)); + assert_eq!(limit1, limit2); + } }