Skip to content

Commit

Permalink
feat: add geoblock tower middleware (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
xav authored Sep 26, 2023
1 parent 99e610e commit 03ccdfd
Show file tree
Hide file tree
Showing 7 changed files with 432 additions and 0 deletions.
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,31 @@ full = [
"future",
"http",
"metrics",
"geoblock"
]
alloc = ["dep:alloc"]
collections = ["dep:collections"]
future = ["dep:future"]
http = []
metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"]
profiler = ["alloc/profiler"]
geoblock = ["dep:geoblock"]

[dependencies]
alloc = { path = "./crates/alloc", optional = true }
collections = { path = "./crates/collections", optional = true }
future = { path = "./crates/future", optional = true }
http = { path = "./crates/http", optional = true }
metrics = { path = "./crates/metrics", optional = true }
geoblock = { path = "./crates/geoblock", optional = true }

[dev-dependencies]
anyhow = "1"
structopt = { version = "0.3", default-features = false }
tokio = { version = "1", features = ["full"] }
hyper = { version = "0.14", features = ["full"] }
tower = { version = "0.4", features = ["util", "filter"] }
axum = "0.6.1"

[[example]]
name = "alloc_profiler"
Expand All @@ -51,3 +57,7 @@ required-features = ["alloc", "metrics"]
[[example]]
name = "metrics"
required-features = ["metrics", "future"]

[[example]]
name = "geoblock"
required-features = ["geoblock"]
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ Metrics and other utils for HTTP servers.

Global service metrics. Currently based on `opentelemetry` SDK and exported in `prometheus` format.

## `geoblock`

Tower middleware for blocking requests based on clients' IP origin.

Note: this middleware requires you to use
[Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info)
to run your app otherwise it will fail at runtime.

See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.

## Examples

- [Metrics integration](examples/metrics.rs). Prints service metrics in the default (`prometheus`) format.
Expand Down
24 changes: 24 additions & 0 deletions crates/geoblock/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "geoblock"
version = "0.1.0"
edition = "2021"

[features]
default = []

[dependencies]
tracing = "0.1"
axum = "0.6"
thiserror = "1.0"
hyper = "0.14"
tower = "0.4"
tower-layer = "0.3"
pin-project = "1"
futures-core = "0.3"
http-body = "0.4"
axum-client-ip = "0.4"
geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"}

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
axum = "0.6"
264 changes: 264 additions & 0 deletions crates/geoblock/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
//! Middleware which adds geo-location IP blocking.
//!
//! Note: this middleware requires you to use
//! [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info)
//! to run your app otherwise it will fail at runtime.
//!
//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.
pub use geoip;
use {
axum_client_ip::InsecureClientIp,
geoip::GeoIpResolver,
http_body::Body,
hyper::{Request, Response, StatusCode},
pin_project::pin_project,
std::{
future::Future,
net::IpAddr,
pin::Pin,
task::{Context, Poll},
},
thiserror::Error,
tower::Service,
tower_layer::Layer,
tracing::{error, info},
};

#[cfg(test)]
mod tests;

/// Values used to configure the middleware behavior when country information
/// could not be retrieved.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BlockingPolicy {
Block,
AllowMissingCountry,
AllowExtractFailure,
AllowAll,
}

#[derive(Debug, Error)]
enum GeoBlockError {
#[error("Country is blocked")]
Blocked,

#[error("Unable to extract IP address")]
UnableToExtractIPAddress,

#[error("Unable to extract geo data from IP address")]
UnableToExtractGeoData,

#[error("Country could not be found in database")]
CountryNotFound,
}

/// Layer that applies the GeoBlock middleware which blocks requests base on IP
/// geo-location.
#[derive(Debug, Clone)]
#[must_use]
pub struct GeoBlockLayer<R>
where
R: GeoIpResolver,
{
blocked_countries: Vec<String>,
ip_resolver: R,
blocking_policy: BlockingPolicy,
}

impl<R> GeoBlockLayer<R>
where
R: GeoIpResolver,
{
pub fn new(
ip_resolver: R,
blocked_countries: Vec<String>,
blocking_policy: BlockingPolicy,
) -> Self {
Self {
ip_resolver,
blocked_countries,
blocking_policy,
}
}
}

impl<S, R> Layer<S> for GeoBlockLayer<R>
where
R: GeoIpResolver,
{
type Service = GeoBlockService<S, R>;

fn layer(&self, inner: S) -> Self::Service {
GeoBlockService::new(
inner,
self.ip_resolver.clone(),
self.blocked_countries.clone(),
self.blocking_policy,
)
}
}

/// Layer that applies the GeoBlock middleware which blocks requests base on IP
/// geo-location.
#[derive(Debug, Clone)]
#[must_use]
pub struct GeoBlockService<S, R>
where
R: GeoIpResolver,
{
inner: S,
blocked_countries: Vec<String>,
ip_resolver: R,
blocking_policy: BlockingPolicy,
}

impl<S, R> GeoBlockService<S, R>
where
R: GeoIpResolver,
{
pub fn new(
inner: S,
ip_resolver: R,
blocked_countries: Vec<String>,
blocking_policy: BlockingPolicy,
) -> Self {
Self {
inner,
blocking_policy,
blocked_countries,
ip_resolver,
}
}

/// Extracts the IP address from the request.
fn extract_ip<ReqBody>(&self, req: &Request<ReqBody>) -> Result<IpAddr, GeoBlockError> {
let client_ip = InsecureClientIp::from(req.headers(), req.extensions())
.map_err(|_| GeoBlockError::UnableToExtractIPAddress)?;
Ok(client_ip.0)
}

/// Checks if the specified IP address is allowed to access the service.
fn check_ip(&self, caller: IpAddr) -> Result<(), GeoBlockError> {
let country = self
.ip_resolver
.lookup_geo_data(caller)
.map_err(|_| GeoBlockError::UnableToExtractGeoData)?
.country
.ok_or(GeoBlockError::CountryNotFound)?;

if self
.blocked_countries
.iter()
.any(|blocked_country| *blocked_country == *country)
{
Err(GeoBlockError::Blocked)
} else {
Ok(())
}
}

fn check_caller<ReqBody>(&self, req: &Request<ReqBody>) -> Result<(), GeoBlockError> {
self.check_ip(self.extract_ip(req)?)
}
}

impl<S, R, ReqBody, ResBody> Service<Request<ReqBody>> for GeoBlockService<S, R>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
R: GeoIpResolver,
ResBody: Body + Default,
{
type Error = S::Error;
type Future = ResponseFuture<S::Future, ResBody>;
type Response = S::Response;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
match self.check_caller(&request) {
Ok(_) => ResponseFuture::future(self.inner.call(request)),

Err(GeoBlockError::Blocked) => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
ResponseFuture::error(res)
}

Err(GeoBlockError::UnableToExtractIPAddress)
| Err(GeoBlockError::UnableToExtractGeoData) => {
if self.blocking_policy == BlockingPolicy::AllowExtractFailure {
info!("Unable to extract client IP address");
return ResponseFuture::future(self.inner.call(request));
}

let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
ResponseFuture::error(res)
}

Err(e) => {
if self.blocking_policy == BlockingPolicy::AllowMissingCountry {
error!("Unable to extract client IP address: {}", e);
return ResponseFuture::future(self.inner.call(request));
}

let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
ResponseFuture::error(res)
}
}
}
}

/// Response future for [`GeoBlock`].
#[pin_project]
pub struct ResponseFuture<F, B> {
#[pin]
inner: Kind<F, B>,
}

#[pin_project(project = KindProj)]
enum Kind<F, B> {
Future {
#[pin]
future: F,
},
Error {
response: Option<Response<B>>,
},
}

impl<F, B> ResponseFuture<F, B> {
fn future(future: F) -> Self {
Self {
inner: Kind::Future { future },
}
}

fn error(res: Response<B>) -> Self {
Self {
inner: Kind::Error {
response: Some(res),
},
}
}
}

impl<F, B, E> Future for ResponseFuture<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::Future { future } => future.poll(cx),
KindProj::Error { response } => {
let response = response.take().unwrap();
Poll::Ready(Ok(response))
}
}
}
}
Loading

0 comments on commit 03ccdfd

Please sign in to comment.