Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add geoblock tower middleware #4

Merged
merged 7 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,31 @@ full = [
"future",
"http",
"metrics",
"geoblock"
]
alloc = ["dep:alloc"]
collections = ["dep:collections"]
future = ["dep:future"]
http = []
metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"]
profiler = ["alloc/profiler"]
geoblock = ["dep:geoblock"]

[dependencies]
alloc = { path = "./crates/alloc", optional = true }
collections = { path = "./crates/collections", optional = true }
future = { path = "./crates/future", optional = true }
http = { path = "./crates/http", optional = true }
metrics = { path = "./crates/metrics", optional = true }
geoblock = { path = "./crates/geoblock", optional = true }

[dev-dependencies]
anyhow = "1"
structopt = { version = "0.3", default-features = false }
tokio = { version = "1", features = ["full"] }
hyper = { version = "0.14", features = ["full"] }
tower = { version = "0.4", features = ["util", "filter"] }
axum = "0.6.1"

[[example]]
name = "alloc_profiler"
Expand All @@ -51,3 +57,7 @@ required-features = ["alloc", "metrics"]
[[example]]
name = "metrics"
required-features = ["metrics", "future"]

[[example]]
name = "geoblock"
required-features = ["geoblock"]
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ Metrics and other utils for HTTP servers.

Global service metrics. Currently based on `opentelemetry` SDK and exported in `prometheus` format.

## `geoblock`

Tower middleware for blocking requests based on clients' IP origin.

Note: this middleware requires you to use
[Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info)
to run your app otherwise it will fail at runtime.

See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.

## Examples

- [Metrics integration](examples/metrics.rs). Prints service metrics in the default (`prometheus`) format.
Expand Down
24 changes: 24 additions & 0 deletions crates/geoblock/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "geoblock"
version = "0.1.0"
edition = "2021"

[features]
default = []

[dependencies]
tracing = "0.1"
axum = "0.6"
thiserror = "1.0"
hyper = "0.14"
tower = "0.4"
tower-layer = "0.3"
pin-project = "1"
futures-core = "0.3"
http-body = "0.4"
axum-client-ip = "0.4"
geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"}

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
axum = "0.6"
265 changes: 265 additions & 0 deletions crates/geoblock/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
//! Middleware which adds geo-location IP blocking.
//!
//! Note: this middleware requires you to use
//! [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info)
//! to run your app otherwise it will fail at runtime.
//!
//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.
pub use geoip;
xav marked this conversation as resolved.
Show resolved Hide resolved
xav marked this conversation as resolved.
Show resolved Hide resolved
use {
axum_client_ip::InsecureClientIp,
geoip::GeoIpResolver,
http_body::Body,
hyper::{Request, Response, StatusCode},
pin_project::pin_project,
std::{
future::Future,
net::IpAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
},
thiserror::Error,
tower::Service,
tower_layer::Layer,
tracing::{error, info},
};

#[cfg(test)]
mod tests;

/// Values used to configure the middleware behavior when country information
/// could not be retrieved.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BlockingPolicy {
Block,
AllowMissingCountry,
AllowExtractFailure,
AllowAll,
}

#[derive(Debug, Error)]
enum GeoBlockError {
#[error("Country is blocked")]
Blocked,

#[error("Unable to extract IP address")]
UnableToExtractIPAddress,

#[error("Unable to extract geo data from IP address")]
UnableToExtractGeoData,

#[error("Country could not be found in database")]
CountryNotFound,
}

/// Layer that applies the GeoBlock middleware which blocks requests base on IP
/// geo-location.
#[derive(Debug, Clone)]
#[must_use]
pub struct GeoBlockLayer<R>
where
R: GeoIpResolver,
{
blocked_countries: Vec<Arc<str>>,
xav marked this conversation as resolved.
Show resolved Hide resolved
ip_resolver: R,
blocking_policy: BlockingPolicy,
}

impl<R> GeoBlockLayer<R>
where
R: GeoIpResolver,
{
pub fn new(
ip_resolver: R,
blocked_countries: Vec<Arc<str>>,
blocking_policy: BlockingPolicy,
) -> Self {
Self {
ip_resolver,
blocked_countries,
blocking_policy,
}
}
}

impl<S, R> Layer<S> for GeoBlockLayer<R>
where
R: GeoIpResolver,
{
type Service = GeoBlockService<S, R>;

fn layer(&self, inner: S) -> Self::Service {
GeoBlockService::new(
inner,
self.ip_resolver.clone(),
self.blocked_countries.clone(),
self.blocking_policy,
)
}
}

/// Layer that applies the GeoBlock middleware which blocks requests base on IP
/// geo-location.
#[derive(Debug, Clone)]
#[must_use]
pub struct GeoBlockService<S, R>
where
R: GeoIpResolver,
{
inner: S,
blocked_countries: Vec<Arc<str>>,
ip_resolver: R,
blocking_policy: BlockingPolicy,
}

impl<S, R> GeoBlockService<S, R>
where
R: GeoIpResolver,
{
pub fn new(
inner: S,
ip_resolver: R,
blocked_countries: Vec<Arc<str>>,
blocking_policy: BlockingPolicy,
) -> Self {
Self {
inner,
blocking_policy,
blocked_countries,
ip_resolver,
}
}

/// Extracts the IP address from the request.
fn extract_ip<ReqBody>(&self, req: &Request<ReqBody>) -> Result<IpAddr, GeoBlockError> {
let client_ip = InsecureClientIp::from(req.headers(), req.extensions())
.map_err(|_| GeoBlockError::UnableToExtractIPAddress)?;
Ok(client_ip.0)
}

/// Checks if the specified IP address is allowed to access the service.
fn check_ip(&self, caller: IpAddr) -> Result<(), GeoBlockError> {
let country = self
.ip_resolver
.lookup_geo_data(caller)
.map_err(|_| GeoBlockError::UnableToExtractGeoData)?
.country
.ok_or(GeoBlockError::CountryNotFound)?;

if self
.blocked_countries
.iter()
.any(|blocked_country| *blocked_country == country)
{
Err(GeoBlockError::Blocked)
} else {
Ok(())
}
}

fn check_caller<ReqBody>(&self, req: &Request<ReqBody>) -> Result<(), GeoBlockError> {
self.check_ip(self.extract_ip(req)?)
}
}

impl<S, R, ReqBody, ResBody> Service<Request<ReqBody>> for GeoBlockService<S, R>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
R: GeoIpResolver,
ResBody: Body + Default,
{
type Error = S::Error;
type Future = ResponseFuture<S::Future, ResBody>;
type Response = S::Response;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
match self.check_caller(&request) {
Ok(_) => ResponseFuture::future(self.inner.call(request)),

Err(GeoBlockError::Blocked) => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
ResponseFuture::error(res)
}

Err(GeoBlockError::UnableToExtractIPAddress)
| Err(GeoBlockError::UnableToExtractGeoData) => {
if self.blocking_policy == BlockingPolicy::AllowExtractFailure {
info!("Unable to extract client IP address");
return ResponseFuture::future(self.inner.call(request));
}

let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
ResponseFuture::error(res)
}

Err(e) => {
if self.blocking_policy == BlockingPolicy::AllowMissingCountry {
error!("Unable to extract client IP address: {}", e);
return ResponseFuture::future(self.inner.call(request));
}

let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
ResponseFuture::error(res)
}
}
}
xav marked this conversation as resolved.
Show resolved Hide resolved
}

/// Response future for [`GeoBlock`].
#[pin_project]
pub struct ResponseFuture<F, B> {
#[pin]
inner: Kind<F, B>,
}

#[pin_project(project = KindProj)]
enum Kind<F, B> {
Future {
#[pin]
future: F,
},
Error {
response: Option<Response<B>>,
},
}

impl<F, B> ResponseFuture<F, B> {
fn future(future: F) -> Self {
Self {
inner: Kind::Future { future },
}
}

fn error(res: Response<B>) -> Self {
Self {
inner: Kind::Error {
response: Some(res),
},
}
}
}

impl<F, B, E> Future for ResponseFuture<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::Future { future } => future.poll(cx),
KindProj::Error { response } => {
let response = response.take().unwrap();
Poll::Ready(Ok(response))
}
}
}
}
Loading