feat: rate_limiting (token bucket sub module)
PR #14, merged, 7 commits
b0889c1 feat: initial rate_limiting library (geekbrother)
c9bde58 feat: adding Moka caching for single-key calls (geekbrother)
729ec96 chore: updating tests to pass millis, multiple keys, refill time check (geekbrother)
c8af646 chore: passing DateTime type for millis instead of u64 (geekbrother)
85530b0 feat: optimizing and minifying lua script (geekbrother)
b9ecde9 revert: optimizing and minifying lua script (geekbrother)
9befeaa feat: bumping redis and deadpool-redis versions (geekbrother)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Cargo.toml (new file, +19 lines):
```toml
[package]
name = "rate_limit"
version = "0.1.0"
edition = "2021"

[dependencies]
chrono = { version = "0.4", features = ["serde"] }
deadpool-redis = "0.14"
moka = { version = "0.12", features = ["future"] }
redis = { version = "0.24", default-features = false, features = ["script"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tracing = "0.1"

[dev-dependencies]
futures = "0.3"
tokio = { version = "1", features = ["full"] }
uuid = "1.8"
```
docker-compose.yml (new file, +6 lines):
```yaml
version: "3.9"
services:
  redis:
    image: redis:7.0
    ports:
      - "6379:6379"
```
src/lib.rs (new file, +279 lines):
```rust
use {
    chrono::{DateTime, Duration, Utc},
    deadpool_redis::{Pool, PoolError},
    moka::future::Cache,
    redis::{RedisError, Script},
    std::{collections::HashMap, sync::Arc},
};

#[derive(Debug, thiserror::Error)]
#[error("Rate limit exceeded. Try again at {reset}")]
pub struct RateLimitExceeded {
    reset: u64,
}

#[derive(Debug, thiserror::Error)]
pub enum InternalRateLimitError {
    #[error("Redis pool error: {0}")]
    Pool(PoolError),

    #[error("Redis error: {0}")]
    Redis(RedisError),
}

#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
    #[error(transparent)]
    RateLimitExceeded(RateLimitExceeded),

    #[error("Internal error: {0}")]
    Internal(InternalRateLimitError),
}

/// Rate limit check using a token bucket algorithm for one key, plus an
/// in-memory cache for rate-limited keys. The `mem_cache` TTL must be set to
/// the same value as the refill interval.
pub async fn token_bucket(
    mem_cache: &Cache<String, u64>,
    redis_write_pool: &Arc<Pool>,
    key: String,
    max_tokens: u32,
    interval: Duration,
    refill_rate: u32,
    now_millis: DateTime<Utc>,
) -> Result<(), RateLimitError> {
    // Check if the key is in the in-memory cache of rate-limited keys,
    // to avoid the Redis RTT in case of a flood.
    if let Some(reset) = mem_cache.get(&key).await {
        return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
            reset,
        }));
    }

    let result = token_bucket_many(
        redis_write_pool,
        vec![key.clone()],
        max_tokens,
        interval,
        refill_rate,
        now_millis,
    )
    .await
    .map_err(RateLimitError::Internal)?;

    let (remaining, reset) = result.get(&key).expect("Should contain the key");
    if remaining.is_negative() {
        let reset_interval = reset / 1000;

        // Insert the rate-limited key into the in-memory cache to avoid the
        // Redis RTT in case of a flood.
        mem_cache.insert(key, reset_interval).await;

        Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
            reset: reset_interval,
        }))
    } else {
        Ok(())
    }
}

/// Rate limit check using a token bucket algorithm for many keys.
pub async fn token_bucket_many(
    redis_write_pool: &Arc<Pool>,
    keys: Vec<String>,
    max_tokens: u32,
    interval: Duration,
    refill_rate: u32,
    now_millis: DateTime<Utc>,
) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> {
    // Remaining is the number of tokens remaining; -1 means rate limited.
    // Reset is the time at which there will be 1 more token than before. This
    // could, for example, be used to cache a 0 token count.
    Script::new(include_str!("token_bucket.lua"))
        .key(keys)
        .arg(max_tokens)
        .arg(interval.num_milliseconds())
        .arg(refill_rate)
        .arg(now_millis.timestamp_millis())
        .invoke_async::<_, String>(
            &mut redis_write_pool
                .clone()
                .get()
                .await
                .map_err(InternalRateLimitError::Pool)?,
        )
        .await
        .map_err(InternalRateLimitError::Redis)
        .map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON"))
}

#[cfg(test)]
mod tests {
    const REDIS_URI: &str = "redis://localhost:6379";
    const REFILL_INTERVAL_MILLIS: i64 = 100;
    const MAX_TOKENS: u32 = 5;
    const REFILL_RATE: u32 = 1;

    use {
        super::*,
        chrono::Utc,
        deadpool_redis::{Config, Runtime},
        redis::AsyncCommands,
        tokio::time::sleep,
        uuid::Uuid,
    };

    async fn redis_clear_keys(conn_uri: &str, keys: &[String]) {
        let client = redis::Client::open(conn_uri).unwrap();
        let mut conn = client.get_async_connection().await.unwrap();
        for key in keys {
            let _: () = conn.del(key).await.unwrap();
        }
    }

    async fn test_rate_limiting(key: String) {
        let cfg = Config::from_url(REDIS_URI);
        let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
        let rate_limit = |now_millis| {
            let key = key.clone();
            let pool = pool.clone();
            async move {
                token_bucket_many(
                    &pool,
                    vec![key.clone()],
                    MAX_TOKENS,
                    refill_interval,
                    REFILL_RATE,
                    now_millis,
                )
                .await
                .unwrap()
                .get(&key)
                .unwrap()
                .to_owned()
            }
        };
        // Call the rate limit multiple times and assert the results
        // for the remaining tokens count and the reset timestamp.
        let call_rate_limit_loop = |loop_iterations| async move {
            let first_call_millis = Utc::now();
            for i in 0..=loop_iterations {
                let curr_iter = loop_iterations as i64 - i as i64 - 1;

                // Use the first call's timestamp for the first call, then the
                // current timestamp for subsequent calls.
                let result = if i == 0 {
                    rate_limit(first_call_millis).await
                } else {
                    rate_limit(Utc::now()).await
                };

                // Assert the remaining tokens count.
                assert_eq!(result.0, curr_iter);
                // Assert that the reset timestamp is the first call's
                // timestamp plus the refill interval.
                assert_eq!(
                    result.1,
                    (first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS) as u64
                );
            }
            // Return the refill timestamp.
            first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS
        };

        // Call the rate limit until the max tokens limit is reached.
        call_rate_limit_loop(MAX_TOKENS).await;

        // Sleep for the full refill and try again.
        // The token counts should be the same as in the previous iteration
        // because the bucket was fully refilled.
        sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
        let last_timestamp = call_rate_limit_loop(MAX_TOKENS).await;

        // Sleep for just one refill and try again.
        // The result must contain one token, and the reset timestamp should be
        // the last full iteration's call timestamp plus the refill interval.
        sleep(refill_interval.to_std().unwrap()).await;
        let result = rate_limit(Utc::now()).await;
        assert_eq!(result.0, 0);
        assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64);
    }

    #[tokio::test]
    async fn test_token_bucket_many() {
        const KEYS_NUMBER_TO_TEST: usize = 3;
        let keys = (0..KEYS_NUMBER_TO_TEST)
            .map(|_| Uuid::new_v4().to_string())
            .collect::<Vec<String>>();

        // Before running the test, ensure the test keys are cleared.
        redis_clear_keys(REDIS_URI, &keys).await;

        // Start an async test for each key and wait for all to complete.
        let tasks = keys.iter().map(|key| test_rate_limiting(key.clone()));
        futures::future::join_all(tasks).await;

        // Clear the keys after the test.
        redis_clear_keys(REDIS_URI, &keys).await;
    }

    #[tokio::test]
    async fn test_token_bucket() {
        // Create a Moka cache with a TTL equal to the refill interval.
        let cache: Cache<String, u64> = Cache::builder()
            .time_to_live(std::time::Duration::from_millis(
                REFILL_INTERVAL_MILLIS as u64,
            ))
            .build();

        let cfg = Config::from_url(REDIS_URI);
        let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
        let key = Uuid::new_v4().to_string();

        // Before running the test, ensure the test keys are cleared.
        redis_clear_keys(REDIS_URI, &[key.clone()]).await;

        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
        let rate_limit = |now_millis| {
            let key = key.clone();
            let pool = pool.clone();
            let cache = cache.clone();
            async move {
                token_bucket(
                    &cache,
                    &pool,
                    key.clone(),
                    MAX_TOKENS,
                    refill_interval,
                    REFILL_RATE,
                    now_millis,
                )
                .await
            }
        };
        let call_rate_limit_loop = |now_millis| async move {
            for i in 0..=MAX_TOKENS {
                let result = rate_limit(now_millis).await;
                if i == MAX_TOKENS {
                    assert!(result
                        .err()
                        .unwrap()
                        .to_string()
                        .contains("Rate limit exceeded"));
                } else {
                    assert!(result.is_ok());
                }
            }
        };

        // Call the rate limit until the max tokens limit is reached.
        call_rate_limit_loop(Utc::now()).await;

        // Sleep for the refill and try again.
        sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
        call_rate_limit_loop(Utc::now()).await;

        // Clear the keys after the test.
        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
    }
}
```
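For reference, the state transition that the Lua script performs inside Redis can be sketched as a plain in-process model. This is an illustrative sketch only, not part of the PR (the `TokenBucket` struct and `check` method are invented for the example), but it mirrors the refill-and-consume arithmetic the script applies per key:

```rust
// In-process model of the token bucket state kept in the Redis hash
// (the "refilledAt" and "tokens" fields). All times are in milliseconds.
struct TokenBucket {
    refilled_at: i64, // timestamp of the last refill
    tokens: i64,      // tokens currently available
}

impl TokenBucket {
    fn new(now: i64, max_tokens: i64) -> Self {
        // A missing key starts as a full bucket, as in the Lua script.
        Self { refilled_at: now, tokens: max_tokens }
    }

    /// Mirrors one script invocation for one key: refill the elapsed whole
    /// intervals, then try to consume one token. Returns
    /// (remaining, reset_millis), with remaining == -1 when rate limited.
    fn check(&mut self, now: i64, max_tokens: i64, interval: i64, refill_rate: i64) -> (i64, i64) {
        if now >= self.refilled_at + interval {
            let num_refills = (now - self.refilled_at) / interval;
            self.tokens = (self.tokens + num_refills * refill_rate).min(max_tokens);
            self.refilled_at += num_refills * interval;
        }
        if self.tokens == 0 {
            (-1, self.refilled_at + interval)
        } else {
            self.tokens -= 1;
            (self.tokens, self.refilled_at + interval)
        }
    }
}
```

With the test constants above (5 max tokens, 100 ms interval, refill rate 1), six back-to-back calls drain the bucket to a rate-limited `-1`, and one token reappears one interval later.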
token_bucket.lua (new file, +44 lines):
```lua
-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311
local keys = KEYS -- identifiers including prefixes
local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds

local results = {}

for i, key in ipairs(keys) do
    local bucket = redis.call("HMGET", key, "refilledAt", "tokens")

    local refilledAt
    local tokens

    if bucket[1] == false then
        refilledAt = now
        tokens = maxTokens
    else
        refilledAt = tonumber(bucket[1])
        tokens = tonumber(bucket[2])
    end

    if now >= refilledAt + interval then
        local numRefills = math.floor((now - refilledAt) / interval)
        tokens = math.min(maxTokens, tokens + numRefills * refillRate)

        refilledAt = refilledAt + numRefills * interval
    end

    if tokens == 0 then
        results[key] = {-1, refilledAt + interval}
    else
        local remaining = tokens - 1
        local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval

        redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining)
        redis.call("PEXPIRE", key, expireAt)
        results[key] = {remaining, refilledAt + interval}
    end
end

-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613
return cjson.encode(results)
```
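One detail worth calling out is the `PEXPIRE` value above: the key's TTL is set to the time needed to refill the bucket completely, so an idle bucket expires from Redis exactly when it would have become full again anyway. A small sketch of that arithmetic (the function name is invented for illustration):

```rust
// TTL assigned to a bucket key after consuming a token: the number of whole
// refill intervals needed to get from `remaining` back to `max_tokens`,
// rounded up, times the interval length in milliseconds.
// Mirrors the Lua line: expireAt = ceil((maxTokens - remaining) / refillRate) * interval
fn bucket_ttl_millis(max_tokens: i64, remaining: i64, refill_rate: i64, interval: i64) -> i64 {
    // Integer ceiling of (max_tokens - remaining) / refill_rate.
    let refills_needed = (max_tokens - remaining + refill_rate - 1) / refill_rate;
    refills_needed * interval
}
```

For example, with 5 max tokens, a 100 ms interval, and a refill rate of 1, an empty bucket gets a 500 ms TTL, while a bucket with 4 tokens left gets only 100 ms.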
Review comment: What does this mean? Does `Cache` support different keys with different refill intervals, e.g. for what Notify Server needs?

Reply (geekbrother): The TTL is the same for all records. If different cases require different refill rates, this can be achieved with a single (lowest) refill interval and per-case refill rates, by tuning each refill rate to that single interval, without adding an additional TTL per key.
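To illustrate the reply above: with one shared base interval, different effective rates per use case come from scaling `refill_rate` (and `max_tokens`), not the interval. A hedged sketch of that tuning arithmetic (the helper name is invented for the example; it only works for target rates that are a whole number of tokens per base interval, which is the limitation the reply implies):

```rust
// Given a single shared base interval (in milliseconds), compute the
// refill_rate that yields a desired average rate in tokens per second.
// Returns None when the target rate is not a whole number of tokens per
// base interval, i.e. when it cannot be expressed this way.
fn refill_rate_for(tokens_per_second: u32, base_interval_millis: u32) -> Option<u32> {
    // Tokens per interval, scaled by 1000 to stay in integer arithmetic.
    let scaled = tokens_per_second as u64 * base_interval_millis as u64;
    if scaled % 1000 != 0 {
        return None;
    }
    Some((scaled / 1000) as u32)
}
```

With a 100 ms base interval, a 10 tokens/s use case maps to `refill_rate = 1` and a 50 tokens/s use case to `refill_rate = 5`; a 5 tokens/s target cannot be expressed (0.5 tokens per interval), which is why the shared interval should be chosen as the slowest one needed.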