diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index e6d5f08f..fafefa2c 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -3,13 +3,12 @@ use opentelemetry::propagation::Extractor; use std::collections::HashMap; use std::sync::Arc; +use limitador::CheckResult; use tonic::codegen::http::HeaderMap; use tonic::{transport, transport::Server, Request, Response, Status}; use tracing::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; -use limitador::counter::Counter; - use crate::envoy_rls::server::envoy::config::core::v3::HeaderValue; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_response::Code; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_service_server::{ @@ -29,6 +28,21 @@ pub enum RateLimitHeaders { DraftVersion03, } +impl RateLimitHeaders { + pub fn headers(&self, response: &mut CheckResult) -> Vec { + let mut headers = match self { + RateLimitHeaders::None => Vec::default(), + RateLimitHeaders::DraftVersion03 => response + .response_header() + .into_iter() + .map(|(key, value)| HeaderValue { key, value }) + .collect(), + }; + headers.sort_by(|a, b| a.key.cmp(&b.key)); + headers + } +} + pub struct MyRateLimiter { limiter: Arc, rate_limit_headers: RateLimitHeaders, @@ -142,10 +156,7 @@ impl RateLimitService for MyRateLimiter { overall_code: resp_code.into(), statuses: vec![], request_headers_to_add: vec![], - response_headers_to_add: to_response_header( - &self.rate_limit_headers, - &mut rate_limited_resp.counters, - ), + response_headers_to_add: self.rate_limit_headers.headers(&mut rate_limited_resp), raw_body: vec![], dynamic_metadata: None, quota: None, @@ -155,58 +166,6 @@ impl RateLimitService for MyRateLimiter { } } -pub fn to_response_header( - rate_limit_headers: &RateLimitHeaders, - counters: &mut [Counter], -) -> Vec { - let mut headers = Vec::new(); - match rate_limit_headers { - RateLimitHeaders::None => {} - - // creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html - RateLimitHeaders::DraftVersion03 => { - // sort by the limit remaining.. - counters.sort_by(|a, b| { - let a_remaining = a.remaining().unwrap_or(a.max_value()); - let b_remaining = b.remaining().unwrap_or(b.max_value()); - a_remaining.cmp(&b_remaining) - }); - - let mut all_limits_text = String::with_capacity(20 * counters.len()); - counters.iter_mut().for_each(|counter| { - all_limits_text.push_str( - format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(), - ); - if let Some(name) = counter.limit().name() { - all_limits_text - .push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); - } - }); - - if let Some(counter) = counters.first() { - headers.push(HeaderValue { - key: "X-RateLimit-Limit".to_string(), - value: format!("{}{}", counter.max_value(), all_limits_text), - }); - - let remaining = counter.remaining().unwrap_or(counter.max_value()); - headers.push(HeaderValue { - key: "X-RateLimit-Remaining".to_string(), - value: format!("{}", remaining), - }); - - if let Some(duration) = counter.expires_in() { - headers.push(HeaderValue { - key: "X-RateLimit-Reset".to_string(), - value: format!("{}", duration.as_secs()), - }); - } - } - } - }; - headers -} - struct RateLimitRequestHeaders { inner: HeaderMap, } diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index 97937d69..bc3b91e6 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -3,6 +3,7 @@ use crate::prometheus_metrics::PrometheusMetrics; use crate::Limiter; use actix_web::{http::StatusCode, HttpResponse, HttpResponseBuilder, ResponseError}; use actix_web::{App, HttpServer}; +use limitador::CheckResult; use paperclip::actix::{ api_v2_errors, api_v2_operation, @@ -209,7 +210,7 @@ async fn check_and_report( add_response_header( &mut resp, response_headers.as_str(), - &mut is_rate_limited.counters, + &mut is_rate_limited, ); resp.json(()) } @@ -224,7 +225,7 @@ async fn check_and_report( add_response_header( &mut resp, response_headers.as_str(), - &mut is_rate_limited.counters, + &mut is_rate_limited, ); resp.json(()) } @@ -238,48 +239,21 @@ async fn check_and_report( pub fn add_response_header( resp: &mut HttpResponseBuilder, rate_limit_headers: &str, - counters: &mut [limitador::counter::Counter], + result: &mut CheckResult, ) { - match rate_limit_headers { + if rate_limit_headers == "DraftVersion03" { // creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html - "DraftVersion03" => { - // sort by the limit remaining.. - counters.sort_by(|a, b| { - let a_remaining = a.remaining().unwrap_or(a.max_value()); - let b_remaining = b.remaining().unwrap_or(b.max_value()); - a_remaining.cmp(&b_remaining) - }); - - let mut all_limits_text = String::with_capacity(20 * counters.len()); - counters.iter_mut().for_each(|counter| { - all_limits_text.push_str( - format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(), - ); - if let Some(name) = counter.limit().name() { - all_limits_text - .push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); - } - }); - - if let Some(counter) = counters.first() { - resp.insert_header(( - "X-RateLimit-Limit", - format!("{}{}", counter.max_value(), all_limits_text), - )); - - let remaining = counter.remaining().unwrap_or(counter.max_value()); - resp.insert_header(( - "X-RateLimit-Remaining".to_string(), - format!("{}", remaining), - )); - - if let Some(duration) = counter.expires_in() { - resp.insert_header(("X-RateLimit-Reset", format!("{}", duration.as_secs()))); - } + let headers = result.response_header(); + if let Some(limit) = headers.get("X-RateLimit-Limit") { + resp.insert_header(("X-RateLimit-Limit", limit.clone())); + } + if let Some(remaining) = headers.get("X-RateLimit-Remaining") { + resp.insert_header(("X-RateLimit-Remaining".to_string(), remaining.clone())); + if let Some(duration) = headers.get("X-RateLimit-Reset") { + resp.insert_header(("X-RateLimit-Reset", duration.clone())); } } - _default => {} - }; + } } pub async fn run_http_server( diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 59f07a67..fdc4dc5f 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -226,6 +226,49 @@ pub struct CheckResult { pub limit_name: Option, } +impl CheckResult { + pub fn response_header(&mut self) -> HashMap { + let mut headers = HashMap::new(); + // sort by the limit remaining.. + self.counters.sort_by(|a, b| { + let a_remaining = a.remaining().unwrap_or(a.max_value()); + let b_remaining = b.remaining().unwrap_or(b.max_value()); + a_remaining.cmp(&b_remaining) + }); + + let mut all_limits_text = String::with_capacity(20 * self.counters.len()); + self.counters.iter_mut().for_each(|counter| { + all_limits_text.push_str( + format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(), + ); + if let Some(name) = counter.limit().name() { + all_limits_text.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); + } + }); + + if let Some(counter) = self.counters.first() { + headers.insert( + "X-RateLimit-Limit".to_string(), + format!("{}{}", counter.max_value(), all_limits_text), + ); + + let remaining = counter.remaining().unwrap_or(counter.max_value()); + headers.insert( + "X-RateLimit-Remaining".to_string(), + format!("{}", remaining), + ); + + if let Some(duration) = counter.expires_in() { + headers.insert( + "X-RateLimit-Reset".to_string(), + format!("{}", duration.as_secs()), + ); + } + } + headers + } +} + impl From for bool { fn from(value: CheckResult) -> Self { value.limited