diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c9f0c0a..6cc2c3d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -44,6 +44,16 @@ jobs: rust: stable env: RUST_BACKTRACE: full + services: + redis: + image: redis:7.2-alpine + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 6379:6379 steps: - uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index e8c9dfa..fcd526f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ full = [ "geoip", "http", "metrics", + "rate_limit", ] alloc = ["dep:alloc"] analytics = ["dep:analytics"] @@ -33,6 +34,7 @@ geoip = ["dep:geoip"] http = [] metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"] profiler = ["alloc/profiler"] +rate_limit = ["dep:rate_limit"] [workspace.dependencies] aws-sdk-s3 = "1.13" @@ -45,6 +47,7 @@ future = { path = "./crates/future", optional = true } geoip = { path = "./crates/geoip", optional = true } http = { path = "./crates/http", optional = true } metrics = { path = "./crates/metrics", optional = true } +rate_limit = { path = "./crates/rate_limit", optional = true } [dev-dependencies] anyhow = "1" diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml new file mode 100644 index 0000000..cdee74d --- /dev/null +++ b/crates/rate_limit/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "rate_limit" +version = "0.1.0" +edition = "2021" + +[dependencies] +chrono = { version = "0.4", features = ["serde"] } +deadpool-redis = "0.14" +moka = { version = "0.12", features = ["future"] } +redis = { version = "0.24", default-features = false, features = ["script"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "1.0" +tracing = "0.1" + +[dev-dependencies] +futures = "0.3" +tokio = { version = "1", features = ["full"] } +uuid = "1.8" diff --git a/crates/rate_limit/docker-compose.yml b/crates/rate_limit/docker-compose.yml new file mode 100644 index 0000000..0da745b --- /dev/null +++ b/crates/rate_limit/docker-compose.yml @@ -0,0 +1,6 @@ +version: "3.9" +services: + redis: + image: redis:7.0 + ports: + - "6379:6379" diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs new file mode 100644 index 0000000..301ed0b --- /dev/null +++ b/crates/rate_limit/src/lib.rs @@ -0,0 +1,279 @@ +use { + chrono::{DateTime, Duration, Utc}, + deadpool_redis::{Pool, PoolError}, + moka::future::Cache, + redis::{RedisError, Script}, + std::{collections::HashMap, sync::Arc}, +}; + +#[derive(Debug, thiserror::Error)] +#[error("Rate limit exceeded. Try again at {reset}")] +pub struct RateLimitExceeded { + reset: u64, +} + +#[derive(Debug, thiserror::Error)] +pub enum InternalRateLimitError { + #[error("Redis pool error {0}")] + Pool(PoolError), + + #[error("Redis error: {0}")] + Redis(RedisError), +} + +#[derive(Debug, thiserror::Error)] +pub enum RateLimitError { + #[error(transparent)] + RateLimitExceeded(RateLimitExceeded), + + #[error("Internal error: {0}")] + Internal(InternalRateLimitError), +} + +/// Rate limit check using a token bucket algorithm for one key and in-memory +/// cache for rate-limited keys. `mem_cache` TTL must be set to the same value +/// as the refill interval. +pub async fn token_bucket( + mem_cache: &Cache, + redis_write_pool: &Arc, + key: String, + max_tokens: u32, + interval: Duration, + refill_rate: u32, + now_millis: DateTime, +) -> Result<(), RateLimitError> { + // Check if the key is in the memory cache of rate limited keys + // to omit the redis RTT in case of flood + if let Some(reset) = mem_cache.get(&key).await { + return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded { + reset, + })); + } + + let result = token_bucket_many( + redis_write_pool, + vec![key.clone()], + max_tokens, + interval, + refill_rate, + now_millis, + ) + .await + .map_err(RateLimitError::Internal)?; + + let (remaining, reset) = result.get(&key).expect("Should contain the key"); + if remaining.is_negative() { + let reset_interval = reset / 1000; + + // Insert the rate-limited key into the memory cache to avoid the redis RTT in + // case of flood + mem_cache.insert(key, reset_interval).await; + + Err(RateLimitError::RateLimitExceeded(RateLimitExceeded { + reset: reset_interval, + })) + } else { + Ok(()) + } +} + +/// Rate limit check using a token bucket algorithm for many keys. +pub async fn token_bucket_many( + redis_write_pool: &Arc, + keys: Vec, + max_tokens: u32, + interval: Duration, + refill_rate: u32, + now_millis: DateTime, +) -> Result, InternalRateLimitError> { + // Remaining is number of tokens remaining. -1 for rate limited. + // Reset is the time at which there will be 1 more token than before. This + // could, for example, be used to cache a 0 token count. + Script::new(include_str!("token_bucket.lua")) + .key(keys) + .arg(max_tokens) + .arg(interval.num_milliseconds()) + .arg(refill_rate) + .arg(now_millis.timestamp_millis()) + .invoke_async::<_, String>( + &mut redis_write_pool + .clone() + .get() + .await + .map_err(InternalRateLimitError::Pool)?, + ) + .await + .map_err(InternalRateLimitError::Redis) + .map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON")) +} + +#[cfg(test)] +mod tests { + const REDIS_URI: &str = "redis://localhost:6379"; + const REFILL_INTERVAL_MILLIS: i64 = 100; + const MAX_TOKENS: u32 = 5; + const REFILL_RATE: u32 = 1; + + use { + super::*, + chrono::Utc, + deadpool_redis::{Config, Runtime}, + redis::AsyncCommands, + tokio::time::sleep, + uuid::Uuid, + }; + + async fn redis_clear_keys(conn_uri: &str, keys: &[String]) { + let client = redis::Client::open(conn_uri).unwrap(); + let mut conn = client.get_async_connection().await.unwrap(); + for key in keys { + let _: () = conn.del(key).await.unwrap(); + } + } + + async fn test_rate_limiting(key: String) { + let cfg = Config::from_url(REDIS_URI); + let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap()); + let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap(); + let rate_limit = |now_millis| { + let key = key.clone(); + let pool = pool.clone(); + async move { + token_bucket_many( + &pool, + vec![key.clone()], + MAX_TOKENS, + refill_interval, + REFILL_RATE, + now_millis, + ) + .await + .unwrap() + .get(&key) + .unwrap() + .to_owned() + } + }; + // Function to call rate limit multiple times and assert results + // for tokens count and reset timestamp + let call_rate_limit_loop = |loop_iterations| async move { + let first_call_millis = Utc::now(); + for i in 0..=loop_iterations { + let curr_iter = loop_iterations as i64 - i as i64 - 1; + + // Using the first call timestamp for the first call or produce the current + let result = if i == 0 { + rate_limit(first_call_millis).await + } else { + rate_limit(Utc::now()).await + }; + + // Assert the remaining tokens count + assert_eq!(result.0, curr_iter); + // Assert the reset timestamp should be the first call timestamp + refill + // interval + assert_eq!( + result.1, + (first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS) as u64 + ); + } + // Returning the refill timestamp + first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS + }; + + // Call rate limit until max tokens limit is reached + call_rate_limit_loop(MAX_TOKENS).await; + + // Sleep for the full refill and try again + // Tokens numbers should be the same as the previous iteration because + // they were fully refilled + sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await; + let last_timestamp = call_rate_limit_loop(MAX_TOKENS).await; + + // Sleep for just one refill and try again + // The result must contain one token and the reset timestamp should be + // the last full iteration call timestamp + refill interval + sleep((refill_interval).to_std().unwrap()).await; + let result = rate_limit(Utc::now()).await; + assert_eq!(result.0, 0); + assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64); + } + + #[tokio::test] + async fn test_token_bucket_many() { + const KEYS_NUMBER_TO_TEST: usize = 3; + let keys = (0..KEYS_NUMBER_TO_TEST) + .map(|_| Uuid::new_v4().to_string()) + .collect::>(); + + // Before running the test, ensure the test keys are cleared + redis_clear_keys(REDIS_URI, &keys).await; + + // Start async test for each key and wait for all to complete + let tasks = keys.iter().map(|key| test_rate_limiting(key.clone())); + futures::future::join_all(tasks).await; + + // Clear keys after the test + redis_clear_keys(REDIS_URI, &keys).await; + } + + #[tokio::test] + async fn test_token_bucket() { + // Create Moka cache with a TTL of the refill interval + let cache: Cache = Cache::builder() + .time_to_live(std::time::Duration::from_millis( + REFILL_INTERVAL_MILLIS as u64, + )) + .build(); + + let cfg = Config::from_url(REDIS_URI); + let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap()); + let key = Uuid::new_v4().to_string(); + + // Before running the test, ensure the test keys are cleared + redis_clear_keys(REDIS_URI, &[key.clone()]).await; + + let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap(); + let rate_limit = |now_millis| { + let key = key.clone(); + let pool = pool.clone(); + let cache = cache.clone(); + async move { + token_bucket( + &cache, + &pool, + key.clone(), + MAX_TOKENS, + refill_interval, + REFILL_RATE, + now_millis, + ) + .await + } + }; + let call_rate_limit_loop = |now_millis| async move { + for i in 0..=MAX_TOKENS { + let result = rate_limit(now_millis).await; + if i == MAX_TOKENS { + assert!(result + .err() + .unwrap() + .to_string() + .contains("Rate limit exceeded")); + } else { + assert!(result.is_ok()); + } + } + }; + + // Call rate limit until max tokens limit is reached + call_rate_limit_loop(Utc::now()).await; + + // Sleep for refill and try again + sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await; + call_rate_limit_loop(Utc::now()).await; + + // Clear keys after the test + redis_clear_keys(REDIS_URI, &[key.clone()]).await; + } +} diff --git a/crates/rate_limit/src/token_bucket.lua b/crates/rate_limit/src/token_bucket.lua new file mode 100644 index 0000000..07ec7b1 --- /dev/null +++ b/crates/rate_limit/src/token_bucket.lua @@ -0,0 +1,44 @@ +-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311 +local keys = KEYS -- identifier including prefixes +local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens +local interval = tonumber(ARGV[2]) -- size of the window in milliseconds +local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval +local now = tonumber(ARGV[4]) -- current timestamp in milliseconds + +local results = {} + +for i, key in ipairs(keys) do + local bucket = redis.call("HMGET", key, "refilledAt", "tokens") + + local refilledAt + local tokens + + if bucket[1] == false then + refilledAt = now + tokens = maxTokens + else + refilledAt = tonumber(bucket[1]) + tokens = tonumber(bucket[2]) + end + + if now >= refilledAt + interval then + local numRefills = math.floor((now - refilledAt) / interval) + tokens = math.min(maxTokens, tokens + numRefills * refillRate) + + refilledAt = refilledAt + numRefills * interval + end + + if tokens == 0 then + results[key] = {-1, refilledAt + interval} + else + local remaining = tokens - 1 + local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval + + redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) + redis.call("PEXPIRE", key, expireAt) + results[key] = {remaining, refilledAt + interval} + end +end + +-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 +return cjson.encode(results) diff --git a/src/lib.rs b/src/lib.rs index 5c99951..a16c71a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,3 +13,5 @@ pub use geoip; pub use http; #[cfg(feature = "metrics")] pub use metrics; +#[cfg(feature = "rate_limit")] +pub use rate_limit;