Skip to content

Commit

Permalink
feat: implement rate-limiting middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
geekbrother committed Apr 1, 2024
1 parent d9052b4 commit 6837b66
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 27 deletions.
87 changes: 63 additions & 24 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use {
crate::{state::AppState, utils::network},
axum::extract::{ConnectInfo, MatchedPath, State},
hyper::HeaderMap,
crate::{error::RpcError, state::AppState, utils::network},
axum::{
extract::{ConnectInfo, MatchedPath, State},
http::Request,
middleware::Next,
response::{IntoResponse, Response},
},
serde::{Deserialize, Serialize},
std::{net::SocketAddr, sync::Arc},
wc::{metrics::TaskMetrics, rate_limit::RateLimitExceeded},
tracing::error,
wc::metrics::TaskMetrics,
};

pub mod balance;
Expand Down Expand Up @@ -37,24 +42,58 @@ pub struct SuccessResponse {
status: String,
}

/// Checking rate limit for the request in the handler
pub async fn handle_rate_limit(
state: State<Arc<AppState>>,
headers: HeaderMap,
connect_info: ConnectInfo<SocketAddr>,
path: MatchedPath,
project_id: Option<&str>,
) -> Result<(), RateLimitExceeded> {
state
.rate_limit
.as_ref()
.unwrap()
.is_rate_limited(
path.as_str(),
&network::get_forwarded_ip(headers.clone())
.unwrap_or_else(|| connect_info.0.ip())
.to_string(),
project_id,
)
.await
/// Rate limit middleware that uses `rate_limiting`` token bucket sub crate
/// from the `utils-rs`. IP address and matched path are used as the token key.
pub async fn rate_limit_middleware<B>(
State(state): State<Arc<AppState>>,
req: Request<B>,
next: Next<B>,
) -> Response {
let headers = req.headers().clone();
let connect_info = match req.extensions().get::<ConnectInfo<SocketAddr>>().cloned() {
Some(info) => info,
None => {
error!("Failed to get connect info from request in rate limit middleware");
return next.run(req).await;
}
};
let ip = &network::get_forwarded_ip(headers.clone())
.unwrap_or_else(|| {
error!(
"Failed to get forwarded IP from request in rate limit middleware. Using the \
connect info IP address."
);
connect_info.0.ip()
})
.to_string();
let path = match req.extensions().get::<MatchedPath>().cloned() {
Some(path) => path,
None => {
error!("Failed to get matched path from request in rate limit middleware");
return next.run(req).await;
}
};
// TODO: Get the project ID from the request path and add analytics for the
// rate-limited requests for project ID.
let project_id = None;

let rate_limit = match state.rate_limit.as_ref() {
Some(rate_limit) => rate_limit,
None => {
error!(
"Rate limiting is not enabled in the state, but called in the rate limit \
middleware"
);
return next.run(req).await;
}
};

let is_rate_limited_result = rate_limit
.is_rate_limited(path.as_str(), ip, project_id)
.await;

match is_rate_limited_result {
Ok(_) => next.run(req).await,
Err(e) => RpcError::from(e).into_response(),
}
}
17 changes: 14 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use {
crate::{
env::Config,
handlers::identity::IdentityResponse,
handlers::{identity::IdentityResponse, rate_limit_middleware},
metrics::Metrics,
project::Registry,
providers::ProvidersConfig,
Expand All @@ -12,6 +12,7 @@ use {
aws_sdk_s3::{config::Region, Client as S3Client},
axum::{
extract::connect_info::IntoMakeServiceWithConnectInfo,
middleware,
response::Response,
routing::{get, post},
Router,
Expand Down Expand Up @@ -319,18 +320,28 @@ pub async fn bootstrap(config: Config) -> RpcResult<()> {
"/v1/convert/build-transaction",
post(handlers::convert::transaction::handler),
)
.route_layer(tracing_and_metrics_layer)
.route("/health", get(handlers::health::handler))
.route_layer(tracing_and_metrics_layer)
.layer(cors);

let app = if let Some(geoblock) = geoblock {
app.layer(geoblock)
} else {
app
};
let app = if state_arc.rate_limit.is_some() {
app.route_layer(middleware::from_fn_with_state(
state_arc.clone(),
rate_limit_middleware,
))
} else {
app
};

let app = app.with_state(state_arc.clone());

info!("v{}", build_version);
info!("Running RPC Proxy on port {}", port);
info!("Running Blockchain-API server on port {}", port);
let addr: SocketAddr = format!("{host}:{port}")
.parse()
.expect("Invalid socket address");
Expand Down

0 comments on commit 6837b66

Please sign in to comment.