feat: rate_limiting token bucket submodule (#14)
geekbrother authored Apr 1, 2024
1 parent ad42551 commit e2bc55d
Showing 7 changed files with 363 additions and 0 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/ci.yaml
@@ -44,6 +44,16 @@ jobs:
rust: stable
env:
RUST_BACKTRACE: full
services:
redis:
image: redis:7.2-alpine
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 6379:6379
steps:
- uses: actions/checkout@v3

3 changes: 3 additions & 0 deletions Cargo.toml
@@ -23,6 +23,7 @@ full = [
"geoip",
"http",
"metrics",
"rate_limit",
]
alloc = ["dep:alloc"]
analytics = ["dep:analytics"]
@@ -33,6 +34,7 @@ geoip = ["dep:geoip"]
http = []
metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"]
profiler = ["alloc/profiler"]
rate_limit = ["dep:rate_limit"]

[workspace.dependencies]
aws-sdk-s3 = "1.13"
@@ -45,6 +47,7 @@ future = { path = "./crates/future", optional = true }
geoip = { path = "./crates/geoip", optional = true }
http = { path = "./crates/http", optional = true }
metrics = { path = "./crates/metrics", optional = true }
rate_limit = { path = "./crates/rate_limit", optional = true }

[dev-dependencies]
anyhow = "1"
19 changes: 19 additions & 0 deletions crates/rate_limit/Cargo.toml
@@ -0,0 +1,19 @@
[package]
name = "rate_limit"
version = "0.1.0"
edition = "2021"

[dependencies]
chrono = { version = "0.4", features = ["serde"] }
deadpool-redis = "0.14"
moka = { version = "0.12", features = ["future"] }
redis = { version = "0.24", default-features = false, features = ["script"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tracing = "0.1"

[dev-dependencies]
futures = "0.3"
tokio = { version = "1", features = ["full"] }
uuid = "1.8"
6 changes: 6 additions & 0 deletions crates/rate_limit/docker-compose.yml
@@ -0,0 +1,6 @@
version: "3.9"
services:
redis:
image: redis:7.0
ports:
- "6379:6379"
279 changes: 279 additions & 0 deletions crates/rate_limit/src/lib.rs
@@ -0,0 +1,279 @@
use {
chrono::{DateTime, Duration, Utc},
deadpool_redis::{Pool, PoolError},
moka::future::Cache,
redis::{RedisError, Script},
std::{collections::HashMap, sync::Arc},
};

#[derive(Debug, thiserror::Error)]
#[error("Rate limit exceeded. Try again at {reset}")]
pub struct RateLimitExceeded {
reset: u64,
}

#[derive(Debug, thiserror::Error)]
pub enum InternalRateLimitError {
#[error("Redis pool error {0}")]
Pool(PoolError),

#[error("Redis error: {0}")]
Redis(RedisError),
}

#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
#[error(transparent)]
RateLimitExceeded(RateLimitExceeded),

#[error("Internal error: {0}")]
Internal(InternalRateLimitError),
}

/// Rate limit check using a token bucket algorithm for one key and in-memory
/// cache for rate-limited keys. `mem_cache` TTL must be set to the same value
/// as the refill interval.
pub async fn token_bucket(
mem_cache: &Cache<String, u64>,
redis_write_pool: &Arc<Pool>,
key: String,
max_tokens: u32,
interval: Duration,
refill_rate: u32,
now_millis: DateTime<Utc>,
) -> Result<(), RateLimitError> {
// Check if the key is in the in-memory cache of rate-limited keys,
// to avoid the Redis round trip in case of a flood
if let Some(reset) = mem_cache.get(&key).await {
return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
reset,
}));
}

let result = token_bucket_many(
redis_write_pool,
vec![key.clone()],
max_tokens,
interval,
refill_rate,
now_millis,
)
.await
.map_err(RateLimitError::Internal)?;

let (remaining, reset) = result.get(&key).expect("Should contain the key");
if remaining.is_negative() {
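// Convert the reset timestamp from milliseconds to seconds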
let reset_interval = reset / 1000;

// Insert the rate-limited key into the in-memory cache to avoid the
// Redis round trip in case of a flood
mem_cache.insert(key, reset_interval).await;

Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
reset: reset_interval,
}))
} else {
Ok(())
}
}

/// Rate limit check using a token bucket algorithm for many keys.
pub async fn token_bucket_many(
redis_write_pool: &Arc<Pool>,
keys: Vec<String>,
max_tokens: u32,
interval: Duration,
refill_rate: u32,
now_millis: DateTime<Utc>,
) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> {
// `remaining` is the number of tokens remaining; -1 means rate limited.
// `reset` is the timestamp (in milliseconds) at which there will be one
// more token than now. This could, for example, be used to cache a
// 0-token count.
Script::new(include_str!("token_bucket.lua"))
.key(keys)
.arg(max_tokens)
.arg(interval.num_milliseconds())
.arg(refill_rate)
.arg(now_millis.timestamp_millis())
.invoke_async::<_, String>(
&mut redis_write_pool
.clone()
.get()
.await
.map_err(InternalRateLimitError::Pool)?,
)
.await
.map_err(InternalRateLimitError::Redis)
.map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON"))
}

#[cfg(test)]
mod tests {
const REDIS_URI: &str = "redis://localhost:6379";
const REFILL_INTERVAL_MILLIS: i64 = 100;
const MAX_TOKENS: u32 = 5;
const REFILL_RATE: u32 = 1;

use {
super::*,
chrono::Utc,
deadpool_redis::{Config, Runtime},
redis::AsyncCommands,
tokio::time::sleep,
uuid::Uuid,
};

async fn redis_clear_keys(conn_uri: &str, keys: &[String]) {
let client = redis::Client::open(conn_uri).unwrap();
let mut conn = client.get_async_connection().await.unwrap();
for key in keys {
let _: () = conn.del(key).await.unwrap();
}
}

async fn test_rate_limiting(key: String) {
let cfg = Config::from_url(REDIS_URI);
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
let rate_limit = |now_millis| {
let key = key.clone();
let pool = pool.clone();
async move {
token_bucket_many(
&pool,
vec![key.clone()],
MAX_TOKENS,
refill_interval,
REFILL_RATE,
now_millis,
)
.await
.unwrap()
.get(&key)
.unwrap()
.to_owned()
}
};
// Helper to call the rate limiter repeatedly and assert the remaining
// token count and reset timestamp on each iteration
let call_rate_limit_loop = |loop_iterations| async move {
let first_call_millis = Utc::now();
for i in 0..=loop_iterations {
let curr_iter = loop_iterations as i64 - i as i64 - 1;

// Use the first-call timestamp for the first call and the current
// time for subsequent calls
let result = if i == 0 {
rate_limit(first_call_millis).await
} else {
rate_limit(Utc::now()).await
};

// Assert the remaining tokens count
assert_eq!(result.0, curr_iter);
// Assert the reset timestamp should be the first call timestamp + refill
// interval
assert_eq!(
result.1,
(first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS) as u64
);
}
// Return the refill timestamp
first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS
};

// Call rate limit until max tokens limit is reached
call_rate_limit_loop(MAX_TOKENS).await;

// Sleep for the full refill and try again
// The token counts should be the same as in the previous iteration
// because the bucket was fully refilled
sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
let last_timestamp = call_rate_limit_loop(MAX_TOKENS).await;

// Sleep for just one refill interval and try again.
// Exactly one token should have been refilled; consuming it leaves zero
// remaining, and the reset timestamp should be the last loop's refill
// timestamp plus the refill interval
sleep((refill_interval).to_std().unwrap()).await;
let result = rate_limit(Utc::now()).await;
assert_eq!(result.0, 0);
assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64);
}

#[tokio::test]
async fn test_token_bucket_many() {
const KEYS_NUMBER_TO_TEST: usize = 3;
let keys = (0..KEYS_NUMBER_TO_TEST)
.map(|_| Uuid::new_v4().to_string())
.collect::<Vec<String>>();

// Before running the test, ensure the test keys are cleared
redis_clear_keys(REDIS_URI, &keys).await;

// Start async test for each key and wait for all to complete
let tasks = keys.iter().map(|key| test_rate_limiting(key.clone()));
futures::future::join_all(tasks).await;

// Clear keys after the test
redis_clear_keys(REDIS_URI, &keys).await;
}

#[tokio::test]
async fn test_token_bucket() {
// Create Moka cache with a TTL of the refill interval
let cache: Cache<String, u64> = Cache::builder()
.time_to_live(std::time::Duration::from_millis(
REFILL_INTERVAL_MILLIS as u64,
))
.build();

let cfg = Config::from_url(REDIS_URI);
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
let key = Uuid::new_v4().to_string();

// Before running the test, ensure the test keys are cleared
redis_clear_keys(REDIS_URI, &[key.clone()]).await;

let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
let rate_limit = |now_millis| {
let key = key.clone();
let pool = pool.clone();
let cache = cache.clone();
async move {
token_bucket(
&cache,
&pool,
key.clone(),
MAX_TOKENS,
refill_interval,
REFILL_RATE,
now_millis,
)
.await
}
};
let call_rate_limit_loop = |now_millis| async move {
for i in 0..=MAX_TOKENS {
let result = rate_limit(now_millis).await;
if i == MAX_TOKENS {
assert!(result
.err()
.unwrap()
.to_string()
.contains("Rate limit exceeded"));
} else {
assert!(result.is_ok());
}
}
};

// Call rate limit until max tokens limit is reached
call_rate_limit_loop(Utc::now()).await;

// Sleep for refill and try again
sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
call_rate_limit_loop(Utc::now()).await;

// Clear keys after the test
redis_clear_keys(REDIS_URI, &[key.clone()]).await;
}
}
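
For orientation, here is a minimal sketch of how a consumer could wire the module together. The Redis URL, limit values, key naming, and tokio setup below are illustrative assumptions for this write-up, not part of the commit:

use {
    deadpool_redis::{Config, Runtime},
    moka::future::Cache,
    std::sync::Arc,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Illustrative limits: a burst of 10 requests, refilling 1 token per second.
    let max_tokens = 10;
    let refill_interval = chrono::Duration::try_milliseconds(1_000).unwrap();
    let refill_rate = 1;

    // Per the doc comment on `token_bucket`, the in-memory cache TTL must
    // equal the refill interval.
    let mem_cache: Cache<String, u64> = Cache::builder()
        .time_to_live(refill_interval.to_std()?)
        .build();

    let pool = Arc::new(
        Config::from_url("redis://localhost:6379")
            .create_pool(Some(Runtime::Tokio1))?,
    );

    // A hypothetical per-client key; any string works.
    let key = "rate_limit:client:203.0.113.7".to_string();

    match rate_limit::token_bucket(
        &mem_cache,
        &pool,
        key,
        max_tokens,
        refill_interval,
        refill_rate,
        chrono::Utc::now(),
    )
    .await
    {
        Ok(()) => println!("request allowed"),
        Err(err) => println!("request rejected: {err}"),
    }

    Ok(())
}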
44 changes: 44 additions & 0 deletions crates/rate_limit/src/token_bucket.lua
@@ -0,0 +1,44 @@
-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311
local keys = KEYS -- identifier including prefixes
local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds

local results = {}

for i, key in ipairs(keys) do
local bucket = redis.call("HMGET", key, "refilledAt", "tokens")

local refilledAt
local tokens

if bucket[1] == false then
refilledAt = now
tokens = maxTokens
else
refilledAt = tonumber(bucket[1])
tokens = tonumber(bucket[2])
end

if now >= refilledAt + interval then
local numRefills = math.floor((now - refilledAt) / interval)
tokens = math.min(maxTokens, tokens + numRefills * refillRate)

refilledAt = refilledAt + numRefills * interval
end

if tokens == 0 then
results[key] = {-1, refilledAt + interval}
else
local remaining = tokens - 1
local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval

redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining)
redis.call("PEXPIRE", key, expireAt)
results[key] = {remaining, refilledAt + interval}
end
end

-- Redis converts Lua tables to flat array replies and drops string keys,
-- so encode the result as JSON instead: https://stackoverflow.com/a/24302613
return cjson.encode(results)
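
To sanity-check the refill arithmetic, here is a small Rust mirror of the Lua refill step with one worked case (an illustrative translation for this write-up, not code shipped in the commit):

// Mirrors the Lua refill step: advance `refilledAt` by whole elapsed
// intervals and add `refillRate` tokens per interval, capped at `maxTokens`.
fn refill(
    now: i64,
    refilled_at: i64,
    tokens: u32,
    max_tokens: u32,
    interval: i64,
    refill_rate: u32,
) -> (i64, u32) {
    if now >= refilled_at + interval {
        // Integer division floors, matching math.floor in the script
        let num_refills = (now - refilled_at) / interval;
        let tokens = max_tokens.min(tokens + num_refills as u32 * refill_rate);
        (refilled_at + num_refills * interval, tokens)
    } else {
        (refilled_at, tokens)
    }
}

fn main() {
    // Bucket last refilled at t=0 with 1 token left; 5 max tokens, 100 ms
    // interval, 1 token per interval. At t=250 ms two whole intervals have
    // elapsed, so two tokens are added and refilledAt advances to 200 (not
    // 250), which keeps reset timestamps on exact interval boundaries.
    assert_eq!(refill(250, 0, 1, 5, 100, 1), (200, 3));
}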
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -13,3 +13,5 @@ pub use geoip;
pub use http;
#[cfg(feature = "metrics")]
pub use metrics;
#[cfg(feature = "rate_limit")]
pub use rate_limit;
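
Downstream crates presumably opt in through the new rate_limit Cargo feature (also pulled in by full), which gates both the workspace dependency and this re-export.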
