Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add geoblock tower middleware #4

Merged
merged 7 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,31 @@ full = [
"future",
"http",
"metrics",
"geoblock"
]
alloc = ["dep:alloc"]
collections = ["dep:collections"]
future = ["dep:future"]
http = []
metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"]
profiler = ["alloc/profiler"]
geoblock = ["dep:geoblock"]

[dependencies]
alloc = { path = "./crates/alloc", optional = true }
collections = { path = "./crates/collections", optional = true }
future = { path = "./crates/future", optional = true }
http = { path = "./crates/http", optional = true }
metrics = { path = "./crates/metrics", optional = true }
geoblock = { path = "./crates/geoblock", optional = true }

[dev-dependencies]
anyhow = "1"
structopt = { version = "0.3", default-features = false }
tokio = { version = "1", features = ["full"] }
hyper = { version = "0.14", features = ["full"] }
tower = { version = "0.4", features = ["util", "filter"] }
axum = "0.6.1"

[[example]]
name = "alloc_profiler"
Expand All @@ -51,3 +57,7 @@ required-features = ["alloc", "metrics"]
[[example]]
name = "metrics"
required-features = ["metrics", "future"]

[[example]]
name = "geoblock"
required-features = ["geoblock"]
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ Metrics and other utils for HTTP servers.

Global service metrics. Currently based on `opentelemetry` SDK and exported in `prometheus` format.

## `geoblock`

Tower middleware for blocking requests based on clients' IP origin.

Note: this middleware requires you to use
[Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info)
to run your app otherwise it will fail at runtime.

See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.

## Examples

- [Metrics integration](examples/metrics.rs). Prints service metrics in the default (`prometheus`) format.
Expand Down
25 changes: 25 additions & 0 deletions crates/geoblock/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[package]
name = "geoblock"
version = "0.1.0"
edition = "2021"

[features]
default = []
full = ["tracing"]
tracing = ["dep:tracing"]

[dependencies]
axum = "0.6"
thiserror = "1.0"
hyper = "0.14"
tower = "0.4.13"
tower-layer = "0.3.2"
pin-project = "1.0.12"
futures-core = "0.3.28"
http-body = "0.4.5"
geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"}
tracing = { version = "0.1", optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
axum = "0.6.1"
287 changes: 287 additions & 0 deletions crates/geoblock/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
/// Middleware which adds geo-location IP blocking.
xav marked this conversation as resolved.
Show resolved Hide resolved
///
/// Note: this middleware requires you to use
/// [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info)
/// to run your app otherwise it will fail at runtime.
///
/// See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.
pub use geoip;
xav marked this conversation as resolved.
Show resolved Hide resolved
xav marked this conversation as resolved.
Show resolved Hide resolved
#[cfg(feature = "tracing")]
use tracing::{error, info};
use {
axum::{extract::ConnectInfo, http::HeaderMap},
geoip::GeoIpResolver,
http_body::Body,
hyper::{Request, Response, StatusCode},
pin_project::pin_project,
std::{
future::Future,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
},
thiserror::Error,
tower::{Layer, Service},
};

#[cfg(test)]
mod tests;

/// Values used to configure the middleware behavior when country information
/// could not be retrieved.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MissingCountry {
Allow,
Block,
}

#[derive(Debug, Error)]
pub enum GeoBlockError {
#[error("Country is blocked: {country}")]
BlockedCountry { country: Arc<str> },

#[error("Unable to extract IP address")]
UnableToExtractIPAddress,

#[error("Unable to extract geo data from IP address")]
UnableToExtractGeoData,

#[error("Country could not be found in database")]
CountryNotFound,

#[error("Other Error")]
Other {
code: StatusCode,
msg: Option<String>,
headers: Option<HeaderMap>,
},
}

/// Layer that applies the GeoBlock middleware which blocks requests base on IP
/// geo-location.
#[derive(Debug, Clone)]
#[must_use]
pub struct GeoBlockLayer<T>
where
T: GeoIpResolver,
{
missing_country: MissingCountry,
blocked_countries: Vec<Arc<str>>,
xav marked this conversation as resolved.
Show resolved Hide resolved
geoip: T,
}

impl<T> GeoBlockLayer<T>
where
T: GeoIpResolver,
{
pub fn new(
geoip: T,
blocked_countries: Vec<Arc<str>>,
missing_country: MissingCountry,
) -> Self {
Self {
missing_country,
blocked_countries,
geoip,
}
}
}

impl<S, T> Layer<S> for GeoBlockLayer<T>
where
T: GeoIpResolver,
{
type Service = GeoBlock<S, T>;

fn layer(&self, inner: S) -> Self::Service {
GeoBlock::new(
inner,
self.geoip.clone(),
self.blocked_countries.clone(),
self.missing_country,
)
}
}

#[derive(Clone, Debug)]
pub struct GeoBlock<S, R>
where
R: GeoIpResolver,
{
inner: S,
missing_country: MissingCountry,
blocked_countries: Vec<Arc<str>>,
geoip: R,
}

impl<S, R> GeoBlock<S, R>
where
R: GeoIpResolver,
{
fn new(
inner: S,
geoip: R,
blocked_countries: Vec<Arc<str>>,
missing_country: MissingCountry,
) -> Self {
Self {
inner,
missing_country,
blocked_countries,
geoip,
}
}

fn extract_ip<ReqBody>(&self, req: &Request<ReqBody>) -> Result<IpAddr, GeoBlockError> {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
xav marked this conversation as resolved.
Show resolved Hide resolved
.map(|ConnectInfo(addr)| addr.ip())
.ok_or(GeoBlockError::UnableToExtractIPAddress)
}

fn check_caller(&self, caller: IpAddr) -> Result<(), GeoBlockError> {
let country = match self
.geoip
.lookup_geo_data(caller)
.map_err(|_| GeoBlockError::UnableToExtractGeoData)
{
Ok(geo_data) => match geo_data.country {
None if self.missing_country == MissingCountry::Allow => {
#[cfg(feature = "tracing")]
xav marked this conversation as resolved.
Show resolved Hide resolved
{
info!("Country not found, but allowed");
}
return Ok(());
}
None => {
#[cfg(feature = "tracing")]
{
info!("Country not found");
}
return Err(GeoBlockError::CountryNotFound);
}
Some(country) => country,
},
Err(_e) => {
return if self.missing_country == MissingCountry::Allow {
Ok(())
} else {
#[cfg(feature = "tracing")]
{
error!("Unable to extract geo data from IP address: {}", _e);
}
Err(GeoBlockError::UnableToExtractGeoData)
}
}
};

let is_blocked = self
.blocked_countries
.iter()
.any(|blocked_country| *blocked_country == country);

if is_blocked {
Err(GeoBlockError::BlockedCountry { country })
} else {
Ok(())
}
}
}

impl<R, S, ReqBody, ResBody> Service<Request<ReqBody>> for GeoBlock<S, R>
where
R: GeoIpResolver,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ResBody: Body + Default,
{
type Error = S::Error;
type Future = ResponseFuture<S::Future, ResBody>;
type Response = Response<ResBody>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
match self.extract_ip(&req) {
xav marked this conversation as resolved.
Show resolved Hide resolved
Ok(ip_addr) => match self.check_caller(ip_addr) {
Ok(_) => ResponseFuture::future(self.inner.call(req)),
Err(GeoBlockError::BlockedCountry { country: _country }) => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
ResponseFuture::invalid_ip(res)
}
Err(_e) => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
ResponseFuture::invalid_ip(res)
}
},
Err(_e) => {
if self.missing_country == MissingCountry::Allow {
#[cfg(feature = "tracing")]
{
error!("Unable to extract client IP address: {}", _e);
}
return ResponseFuture::future(self.inner.call(req));
}

let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
ResponseFuture::invalid_ip(res)
}
}
}
xav marked this conversation as resolved.
Show resolved Hide resolved
}

#[pin_project]
/// Response future for [`GeoBlock`].
pub struct ResponseFuture<F, B> {
#[pin]
inner: Kind<F, B>,
}

impl<F, B> ResponseFuture<F, B> {
fn future(future: F) -> Self {
Self {
inner: Kind::Future { future },
}
}

fn invalid_ip(res: Response<B>) -> Self {
Self {
inner: Kind::Error {
response: Some(res),
},
}
}
}

#[pin_project(project = KindProj)]
enum Kind<F, B> {
xav marked this conversation as resolved.
Show resolved Hide resolved
Future {
#[pin]
future: F,
},
Error {
response: Option<Response<B>>,
},
}

impl<F, B, E> Future for ResponseFuture<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::Future { future } => future.poll(cx),
KindProj::Error { response } => {
let response = response.take().unwrap();
Poll::Ready(Ok(response))
}
}
}
}
Loading