-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ad42551
commit b0889c1
Showing
7 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[package] | ||
name = "rate_limit" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
chrono = { version = "0.4", features = ["serde"] } | ||
deadpool-redis = "0.12" | ||
redis = { version = "0.23", default-features = false, features = ["script"] } | ||
serde = { version = "1.0", features = ["derive"] } | ||
serde_json = "1.0" | ||
thiserror = "1.0" | ||
tracing = "0.1" | ||
|
||
[dev-dependencies] | ||
anyhow = "1" | ||
tokio = { version = "1", features = ["full"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
version: "3.9" | ||
services: | ||
redis: | ||
image: redis:7.0 | ||
ports: | ||
- "6379:6379" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
use { | ||
chrono::{DateTime, Duration, Utc}, | ||
core::fmt, | ||
deadpool_redis::{Pool, PoolError}, | ||
redis::{RedisError, Script}, | ||
std::{collections::HashMap, sync::Arc}, | ||
}; | ||
|
||
pub type Clock = Option<Arc<dyn ClockImpl>>; | ||
pub trait ClockImpl: fmt::Debug + Send + Sync { | ||
fn now(&self) -> DateTime<Utc>; | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
#[error("Rate limit exceeded. Try again at {reset}")] | ||
pub struct RateLimitExceeded { | ||
reset: u64, | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
pub enum InternalRateLimitError { | ||
#[error("Redis pool error {0}")] | ||
Pool(PoolError), | ||
|
||
#[error("Redis error: {0}")] | ||
Redis(RedisError), | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
pub enum RateLimitError { | ||
#[error(transparent)] | ||
RateLimitExceeded(RateLimitExceeded), | ||
|
||
#[error("Internal error: {0}")] | ||
Internal(InternalRateLimitError), | ||
} | ||
|
||
pub async fn token_bucket( | ||
redis_write_pool: &Arc<Pool>, | ||
key: String, | ||
max_tokens: u32, | ||
interval: Duration, | ||
refill_rate: u32, | ||
) -> Result<(), RateLimitError> { | ||
let result = token_bucket_many( | ||
redis_write_pool, | ||
vec![key.clone()], | ||
max_tokens, | ||
interval, | ||
refill_rate, | ||
) | ||
.await | ||
.map_err(RateLimitError::Internal)?; | ||
|
||
let (remaining, reset) = result.get(&key).expect("Should contain the key"); | ||
if remaining.is_negative() { | ||
Err(RateLimitError::RateLimitExceeded(RateLimitExceeded { | ||
reset: reset / 1000, | ||
})) | ||
} else { | ||
Ok(()) | ||
} | ||
} | ||
|
||
pub async fn token_bucket_many( | ||
redis_write_pool: &Arc<Pool>, | ||
keys: Vec<String>, | ||
max_tokens: u32, | ||
interval: Duration, | ||
refill_rate: u32, | ||
) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> { | ||
let now = Utc::now(); | ||
|
||
// Remaining is number of tokens remaining. -1 for rate limited. | ||
// Reset is the time at which there will be 1 more token than before. This | ||
// could, for example, be used to cache a 0 token count. | ||
Script::new(include_str!("token_bucket.lua")) | ||
.key(keys) | ||
.arg(max_tokens) | ||
.arg(interval.num_milliseconds()) | ||
.arg(refill_rate) | ||
.arg(now.timestamp_millis()) | ||
.invoke_async::<_, String>( | ||
&mut redis_write_pool | ||
.clone() | ||
.get() | ||
.await | ||
.map_err(InternalRateLimitError::Pool)?, | ||
) | ||
.await | ||
.map_err(InternalRateLimitError::Redis) | ||
.map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON")) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
const REDIS_URI: &str = "redis://localhost:6379"; | ||
use { | ||
super::*, | ||
deadpool_redis::{Config, Runtime}, | ||
redis::AsyncCommands, | ||
tokio::time::sleep, | ||
}; | ||
|
||
async fn redis_clear_keys(conn_uri: &str, keys: &[String]) { | ||
let client = redis::Client::open(conn_uri).unwrap(); | ||
let mut conn = client.get_async_connection().await.unwrap(); | ||
for key in keys { | ||
let _: () = conn.del(key).await.unwrap(); | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_token_bucket() { | ||
let cfg = Config::from_url(REDIS_URI); | ||
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap()); | ||
let key = "test_token_bucket".to_string(); | ||
|
||
// Before running the test, ensure the test keys are cleared | ||
redis_clear_keys(REDIS_URI, &[key.clone()]).await; | ||
|
||
let max_tokens = 10; | ||
let refill_interval = chrono::Duration::try_milliseconds(100).unwrap(); | ||
let refill_rate = 1; | ||
let rate_limit = || async { | ||
token_bucket_many( | ||
&pool, | ||
vec![key.clone()], | ||
max_tokens, | ||
refill_interval, | ||
refill_rate, | ||
) | ||
.await | ||
.unwrap() | ||
.get(&key.clone()) | ||
.unwrap() | ||
.to_owned() | ||
}; | ||
|
||
// Iterate over the max tokens | ||
for i in 0..=max_tokens { | ||
let curr_iter = max_tokens as i64 - i as i64 - 1; | ||
let result = rate_limit().await; | ||
assert_eq!(result.0, curr_iter); | ||
} | ||
|
||
// Sleep for refill and try again | ||
// Tokens numbers should be the same as the previous iteration | ||
sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await; | ||
|
||
for i in 0..=max_tokens { | ||
let curr_iter = max_tokens as i64 - i as i64 - 1; | ||
let result = rate_limit().await; | ||
assert_eq!(result.0, curr_iter); | ||
} | ||
|
||
// Clear keys after the test | ||
redis_clear_keys(REDIS_URI, &[key.clone()]).await; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311 | ||
local keys = KEYS -- identifier including prefixes | ||
local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens | ||
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds | ||
local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval | ||
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds | ||
|
||
local results = {} | ||
|
||
for i, key in ipairs(keys) do | ||
local bucket = redis.call("HMGET", key, "refilledAt", "tokens") | ||
|
||
local refilledAt | ||
local tokens | ||
|
||
if bucket[1] == false then | ||
refilledAt = now | ||
tokens = maxTokens | ||
else | ||
refilledAt = tonumber(bucket[1]) | ||
tokens = tonumber(bucket[2]) | ||
end | ||
|
||
if now >= refilledAt + interval then | ||
local numRefills = math.floor((now - refilledAt) / interval) | ||
tokens = math.min(maxTokens, tokens + numRefills * refillRate) | ||
|
||
refilledAt = refilledAt + numRefills * interval | ||
end | ||
|
||
if tokens == 0 then | ||
results[key] = {-1, refilledAt + interval} | ||
else | ||
local remaining = tokens - 1 | ||
local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval | ||
|
||
redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) | ||
redis.call("PEXPIRE", key, expireAt) | ||
results[key] = {remaining, refilledAt + interval} | ||
end | ||
end | ||
|
||
-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 | ||
return cjson.encode(results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters