From 6837b6658147604be15e756dc25997db11124bc5 Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff <geekmaks@gmail.com> Date: Mon, 1 Apr 2024 11:25:53 +0200 Subject: [PATCH] feat: implement rate-limiting middleware --- src/handlers/mod.rs | 87 ++++++++++++++++++++++++++++++++------------- src/lib.rs | 17 +++++++-- 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 3b682667ee..5a83f3633a 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,10 +1,15 @@ use { - crate::{state::AppState, utils::network}, - axum::extract::{ConnectInfo, MatchedPath, State}, - hyper::HeaderMap, + crate::{error::RpcError, state::AppState, utils::network}, + axum::{ + extract::{ConnectInfo, MatchedPath, State}, + http::Request, + middleware::Next, + response::{IntoResponse, Response}, + }, serde::{Deserialize, Serialize}, std::{net::SocketAddr, sync::Arc}, - wc::{metrics::TaskMetrics, rate_limit::RateLimitExceeded}, + tracing::error, + wc::metrics::TaskMetrics, }; pub mod balance; @@ -37,24 +42,58 @@ pub struct SuccessResponse { status: String, } -/// Checking rate limit for the request in the handler -pub async fn handle_rate_limit( - state: State<Arc<AppState>>, - headers: HeaderMap, - connect_info: ConnectInfo<SocketAddr>, - path: MatchedPath, - project_id: Option<&str>, -) -> Result<(), RateLimitExceeded> { - state - .rate_limit - .as_ref() - .unwrap() - .is_rate_limited( - path.as_str(), - &network::get_forwarded_ip(headers.clone()) - .unwrap_or_else(|| connect_info.0.ip()) - .to_string(), - project_id, - ) - .await +/// Rate limit middleware that uses `rate_limiting`` token bucket sub crate +/// from the `utils-rs`. IP address and matched path are used as the token key. +pub async fn rate_limit_middleware<B>( + State(state): State<Arc<AppState>>, + req: Request<B>, + next: Next<B>, +) -> Response { + let headers = req.headers().clone(); + let connect_info = match req.extensions().get::<ConnectInfo<SocketAddr>>().cloned() { + Some(info) => info, + None => { + error!("Failed to get connect info from request in rate limit middleware"); + return next.run(req).await; + } + }; + let ip = &network::get_forwarded_ip(headers.clone()) + .unwrap_or_else(|| { + error!( + "Failed to get forwarded IP from request in rate limit middleware. Using the \ + connect info IP address." + ); + connect_info.0.ip() + }) + .to_string(); + let path = match req.extensions().get::<MatchedPath>().cloned() { + Some(path) => path, + None => { + error!("Failed to get matched path from request in rate limit middleware"); + return next.run(req).await; + } + }; + // TODO: Get the project ID from the request path and add analytics for the + // rate-limited requests for project ID. + let project_id = None; + + let rate_limit = match state.rate_limit.as_ref() { + Some(rate_limit) => rate_limit, + None => { + error!( + "Rate limiting is not enabled in the state, but called in the rate limit \ + middleware" + ); + return next.run(req).await; + } + }; + + let is_rate_limited_result = rate_limit + .is_rate_limited(path.as_str(), ip, project_id) + .await; + + match is_rate_limited_result { + Ok(_) => next.run(req).await, + Err(e) => RpcError::from(e).into_response(), + } } diff --git a/src/lib.rs b/src/lib.rs index bface4a911..788f0a0bc8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ use { crate::{ env::Config, - handlers::identity::IdentityResponse, + handlers::{identity::IdentityResponse, rate_limit_middleware}, metrics::Metrics, project::Registry, providers::ProvidersConfig, @@ -12,6 +12,7 @@ use { aws_sdk_s3::{config::Region, Client as S3Client}, axum::{ extract::connect_info::IntoMakeServiceWithConnectInfo, + middleware, response::Response, routing::{get, post}, Router, @@ -319,18 +320,28 @@ pub async fn bootstrap(config: Config) -> RpcResult<()> { "/v1/convert/build-transaction", post(handlers::convert::transaction::handler), ) - .route_layer(tracing_and_metrics_layer) .route("/health", get(handlers::health::handler)) + .route_layer(tracing_and_metrics_layer) .layer(cors); + let app = if let Some(geoblock) = geoblock { app.layer(geoblock) } else { app }; + let app = if state_arc.rate_limit.is_some() { + app.route_layer(middleware::from_fn_with_state( + state_arc.clone(), + rate_limit_middleware, + )) + } else { + app + }; + let app = app.with_state(state_arc.clone()); info!("v{}", build_version); - info!("Running RPC Proxy on port {}", port); + info!("Running Blockchain-API server on port {}", port); let addr: SocketAddr = format!("{host}:{port}") .parse() .expect("Invalid socket address");