feat: rate_limiting token bucket sub module #14

Merged: 7 commits, Apr 1, 2024
10 changes: 10 additions & 0 deletions .github/workflows/ci.yaml
@@ -44,6 +44,16 @@ jobs:
rust: stable
env:
RUST_BACKTRACE: full
services:
redis:
image: redis:7.2-alpine
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 6379:6379
steps:
- uses: actions/checkout@v3

3 changes: 3 additions & 0 deletions Cargo.toml
@@ -23,6 +23,7 @@ full = [
"geoip",
"http",
"metrics",
"rate_limit",
]
alloc = ["dep:alloc"]
analytics = ["dep:analytics"]
@@ -33,6 +34,7 @@ geoip = ["dep:geoip"]
http = []
metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"]
profiler = ["alloc/profiler"]
rate_limit = ["dep:rate_limit"]

[workspace.dependencies]
aws-sdk-s3 = "1.13"
@@ -45,6 +47,7 @@ future = { path = "./crates/future", optional = true }
geoip = { path = "./crates/geoip", optional = true }
http = { path = "./crates/http", optional = true }
metrics = { path = "./crates/metrics", optional = true }
rate_limit = { path = "./crates/rate_limit", optional = true }

[dev-dependencies]
anyhow = "1"
19 changes: 19 additions & 0 deletions crates/rate_limit/Cargo.toml
@@ -0,0 +1,19 @@
[package]
name = "rate_limit"
version = "0.1.0"
edition = "2021"

[dependencies]
chrono = { version = "0.4", features = ["serde"] }
deadpool-redis = "0.14"
moka = { version = "0.12", features = ["future"] }
redis = { version = "0.24", default-features = false, features = ["script"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tracing = "0.1"

[dev-dependencies]
futures = "0.3"
tokio = { version = "1", features = ["full"] }
uuid = "1.8"
6 changes: 6 additions & 0 deletions crates/rate_limit/docker-compose.yml
@@ -0,0 +1,6 @@
version: "3.9"
services:
redis:
image: redis:7.0
ports:
- "6379:6379"
279 changes: 279 additions & 0 deletions crates/rate_limit/src/lib.rs
@@ -0,0 +1,279 @@
use {
chrono::{DateTime, Duration, Utc},
deadpool_redis::{Pool, PoolError},
moka::future::Cache,
redis::{RedisError, Script},
std::{collections::HashMap, sync::Arc},
};

#[derive(Debug, thiserror::Error)]
#[error("Rate limit exceeded. Try again at {reset}")]
pub struct RateLimitExceeded {
reset: u64,
}

#[derive(Debug, thiserror::Error)]
pub enum InternalRateLimitError {
#[error("Redis pool error {0}")]
Pool(PoolError),

#[error("Redis error: {0}")]
Redis(RedisError),
}

#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
#[error(transparent)]
RateLimitExceeded(RateLimitExceeded),

#[error("Internal error: {0}")]
Internal(InternalRateLimitError),
}

/// Rate-limit check using a token bucket algorithm for a single key, with an
/// in-memory cache of rate-limited keys. The `mem_cache` TTL must be set to the
/// same value as the refill interval.
pub async fn token_bucket(
Review comment (Member), quoting "mem_cache TTL must be set to the same value as the refill interval.":
What does this mean? Does Cache support different keys with different refill intervals, e.g. for what Notify Server needs?

Reply (Contributor, author):
The TTL applies to all cached records. If different refill rates are needed for different use cases, they can be achieved with a single refill interval (the lowest one required) by tuning each key's refill rate against that shared interval, without adding a separate TTL per key.
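A minimal sketch of that reply, using purely illustrative numbers that do not appear in this PR: two logical limits share one refill interval (and therefore one mem_cache TTL) and differ only in max_tokens and refill_rate.

fn main() {
    // Single shared tick of 100 ms; the moka cache TTL would be set to this value once.
    let shared_interval_ms: i64 = 100;

    // "Low" limit: ~10 tokens per second, i.e. 1 token per 100 ms tick.
    let (low_max_tokens, low_refill_rate): (u32, u32) = (10, 1);
    // "High" limit: ~100 tokens per second, i.e. 10 tokens per 100 ms tick.
    let (high_max_tokens, high_refill_rate): (u32, u32) = (100, 10);

    // Effective rate = refill_rate * (1000 / interval) tokens per second.
    assert_eq!(low_refill_rate as i64 * (1000 / shared_interval_ms), 10);
    assert_eq!(high_refill_rate as i64 * (1000 / shared_interval_ms), 100);

    // Both limits would be passed to token_bucket() with the same `interval`
    // argument and the same `mem_cache`; only key, max_tokens and refill_rate differ.
    let _ = (low_max_tokens, high_max_tokens);
}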

mem_cache: &Cache<String, u64>,
redis_write_pool: &Arc<Pool>,
key: String,
max_tokens: u32,
interval: Duration,
refill_rate: u32,
now_millis: DateTime<Utc>,
) -> Result<(), RateLimitError> {
// Check whether the key is already in the in-memory cache of rate-limited keys
// to avoid the Redis round trip in case of a flood
if let Some(reset) = mem_cache.get(&key).await {
return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
reset,
}));
}

let result = token_bucket_many(
redis_write_pool,
vec![key.clone()],
max_tokens,
interval,
refill_rate,
now_millis,
)
.await
.map_err(RateLimitError::Internal)?;

let (remaining, reset) = result.get(&key).expect("Should contain the key");
if remaining.is_negative() {
let reset_interval = reset / 1000;

// Insert the rate-limited key into the in-memory cache to avoid the Redis
// round trip in case of a flood
mem_cache.insert(key, reset_interval).await;

Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
reset: reset_interval,
}))
} else {
Ok(())
}
}

/// Rate limit check using a token bucket algorithm for many keys.
pub async fn token_bucket_many(
redis_write_pool: &Arc<Pool>,
keys: Vec<String>,
max_tokens: u32,
interval: Duration,
refill_rate: u32,
now_millis: DateTime<Utc>,
) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> {
// `remaining` is the number of tokens left; -1 means the key is rate-limited.
// `reset` is the time at which there will be one more token than now; this
// could, for example, be used to cache a zero-token count.
Script::new(include_str!("token_bucket.lua"))
.key(keys)
.arg(max_tokens)
.arg(interval.num_milliseconds())
.arg(refill_rate)
.arg(now_millis.timestamp_millis())
.invoke_async::<_, String>(
&mut redis_write_pool
.clone()
.get()
.await
.map_err(InternalRateLimitError::Pool)?,
)
.await
.map_err(InternalRateLimitError::Redis)
.map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON"))
}

#[cfg(test)]
mod tests {
const REDIS_URI: &str = "redis://localhost:6379";
const REFILL_INTERVAL_MILLIS: i64 = 100;
const MAX_TOKENS: u32 = 5;
const REFILL_RATE: u32 = 1;

use {
super::*,
chrono::Utc,
deadpool_redis::{Config, Runtime},
redis::AsyncCommands,
tokio::time::sleep,
uuid::Uuid,
};

async fn redis_clear_keys(conn_uri: &str, keys: &[String]) {
let client = redis::Client::open(conn_uri).unwrap();
let mut conn = client.get_async_connection().await.unwrap();
for key in keys {
let _: () = conn.del(key).await.unwrap();
}
}

async fn test_rate_limiting(key: String) {
let cfg = Config::from_url(REDIS_URI);
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
let rate_limit = |now_millis| {
let key = key.clone();
let pool = pool.clone();
async move {
token_bucket_many(
&pool,
vec![key.clone()],
MAX_TOKENS,
refill_interval,
REFILL_RATE,
now_millis,
)
.await
.unwrap()
.get(&key)
.unwrap()
.to_owned()
}
};
// Helper to call the rate limit multiple times and assert the results
// (token count and reset timestamp)
let call_rate_limit_loop = |loop_iterations| async move {
let first_call_millis = Utc::now();
for i in 0..=loop_iterations {
let curr_iter = loop_iterations as i64 - i as i64 - 1;

// Use the first-call timestamp on the first iteration, otherwise use the current time
let result = if i == 0 {
rate_limit(first_call_millis).await
} else {
rate_limit(Utc::now()).await
};

// Assert the remaining tokens count
assert_eq!(result.0, curr_iter);
// Assert that the reset timestamp is the first-call timestamp + refill interval
assert_eq!(
result.1,
(first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS) as u64
);
}
// Returning the refill timestamp
first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS
};

// Call rate limit until max tokens limit is reached
call_rate_limit_loop(MAX_TOKENS).await;

// Sleep for a full refill and try again.
// Token counts should match the previous iteration because the bucket
// was fully refilled
sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
let last_timestamp = call_rate_limit_loop(MAX_TOKENS).await;

// Sleep for just one refill interval and try again.
// Exactly one token should have been refilled and is consumed by this call
// (remaining == 0), and the reset timestamp should be the last full-iteration
// call timestamp + refill interval
sleep((refill_interval).to_std().unwrap()).await;
let result = rate_limit(Utc::now()).await;
assert_eq!(result.0, 0);
assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64);
}

#[tokio::test]
async fn test_token_bucket_many() {
const KEYS_NUMBER_TO_TEST: usize = 3;
let keys = (0..KEYS_NUMBER_TO_TEST)
.map(|_| Uuid::new_v4().to_string())
.collect::<Vec<String>>();

// Before running the test, ensure the test keys are cleared
redis_clear_keys(REDIS_URI, &keys).await;

// Start async test for each key and wait for all to complete
let tasks = keys.iter().map(|key| test_rate_limiting(key.clone()));
futures::future::join_all(tasks).await;

// Clear keys after the test
redis_clear_keys(REDIS_URI, &keys).await;
}

#[tokio::test]
async fn test_token_bucket() {
// Create a Moka cache with a TTL equal to the refill interval
let cache: Cache<String, u64> = Cache::builder()
.time_to_live(std::time::Duration::from_millis(
REFILL_INTERVAL_MILLIS as u64,
))
.build();

let cfg = Config::from_url(REDIS_URI);
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
let key = Uuid::new_v4().to_string();

// Before running the test, ensure the test keys are cleared
redis_clear_keys(REDIS_URI, &[key.clone()]).await;

let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
let rate_limit = |now_millis| {
let key = key.clone();
let pool = pool.clone();
let cache = cache.clone();
async move {
token_bucket(
&cache,
&pool,
key.clone(),
MAX_TOKENS,
refill_interval,
REFILL_RATE,
now_millis,
)
.await
}
};
let call_rate_limit_loop = |now_millis| async move {
for i in 0..=MAX_TOKENS {
let result = rate_limit(now_millis).await;
if i == MAX_TOKENS {
assert!(result
.err()
.unwrap()
.to_string()
.contains("Rate limit exceeded"));
} else {
assert!(result.is_ok());
}
}
};

// Call rate limit until max tokens limit is reached
call_rate_limit_loop(Utc::now()).await;

// Sleep for refill and try again
sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
call_rate_limit_loop(Utc::now()).await;

// Clear keys after the test
redis_clear_keys(REDIS_URI, &[key.clone()]).await;
}
}
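For reviewers, a minimal usage sketch of the API added in this file. It assumes the rate_limit feature is enabled, tokio with the macros feature, and a Redis instance reachable at the URL below; the URL, key, and limit values are placeholders, not anything specified by this PR.

use {
    deadpool_redis::{Config, Runtime},
    moka::future::Cache,
    std::sync::Arc,
};

#[tokio::main]
async fn main() {
    let refill_interval = chrono::Duration::try_milliseconds(100).unwrap();

    // Per the review discussion above, the in-memory cache TTL equals the refill
    // interval, so a cached "rate limited" entry expires when a token refills.
    let mem_cache: Cache<String, u64> = Cache::builder()
        .time_to_live(std::time::Duration::from_millis(100))
        .build();

    let redis_pool = Arc::new(
        Config::from_url("redis://localhost:6379")
            .create_pool(Some(Runtime::Tokio1))
            .unwrap(),
    );

    let result = rate_limit::token_bucket(
        &mem_cache,
        &redis_pool,
        "client:198.51.100.7".to_string(), // key, e.g. per client IP
        5,                                 // max_tokens
        refill_interval,                   // interval
        1,                                 // refill_rate per interval
        chrono::Utc::now(),                // now_millis
    )
    .await;

    match result {
        Ok(()) => println!("request allowed"),
        Err(e) => println!("request rejected: {e}"),
    }
}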
44 changes: 44 additions & 0 deletions crates/rate_limit/src/token_bucket.lua
@@ -0,0 +1,44 @@
-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311
local keys = KEYS -- identifiers, including prefixes
local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds

local results = {}

for i, key in ipairs(keys) do
local bucket = redis.call("HMGET", key, "refilledAt", "tokens")

local refilledAt
local tokens

if bucket[1] == false then
refilledAt = now
tokens = maxTokens
else
refilledAt = tonumber(bucket[1])
tokens = tonumber(bucket[2])
end

if now >= refilledAt + interval then
local numRefills = math.floor((now - refilledAt) / interval)
tokens = math.min(maxTokens, tokens + numRefills * refillRate)

refilledAt = refilledAt + numRefills * interval
end

if tokens == 0 then
results[key] = {-1, refilledAt + interval}
else
local remaining = tokens - 1
local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval

redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining)
redis.call("PEXPIRE", key, expireAt)
results[key] = {remaining, refilledAt + interval}
end
end

-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613
return cjson.encode(results)
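As a sanity check on the refill arithmetic above, here is the same computation in Rust with illustrative numbers (max_tokens = 5, interval = 100 ms, refill_rate = 1; none of these values are mandated by the script).

fn main() {
    let (max_tokens, interval, refill_rate): (i64, i64, i64) = (5, 100, 1);

    // Bucket state: last refilled at t = 1000 ms with 2 tokens; request arrives at t = 1250 ms.
    let (mut refilled_at, mut tokens, now) = (1000_i64, 2_i64, 1250_i64);

    if now >= refilled_at + interval {
        let num_refills = (now - refilled_at) / interval; // floor division: 2 intervals elapsed
        tokens = max_tokens.min(tokens + num_refills * refill_rate); // 2 + 2 * 1 = 4
        refilled_at += num_refills * interval; // 1000 + 200 = 1200
    }

    // This request consumes one token.
    let remaining = tokens - 1; // 3
    // PEXPIRE value: time until the bucket would be full again (ceiling division).
    let expire_at = ((max_tokens - remaining) + refill_rate - 1) / refill_rate * interval; // 200

    // The script returns {remaining, refilledAt + interval} for this key.
    assert_eq!((remaining, refilled_at + interval, expire_at), (3, 1300, 200));
}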
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -13,3 +13,5 @@ pub use geoip;
pub use http;
#[cfg(feature = "metrics")]
pub use metrics;
#[cfg(feature = "rate_limit")]
pub use rate_limit;