diff --git a/Cargo.toml b/Cargo.toml index bd95ad7..2ad6ccc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ full = [ "future", "http", "metrics", + "geoblock" ] alloc = ["dep:alloc"] collections = ["dep:collections"] @@ -27,6 +28,7 @@ future = ["dep:future"] http = [] metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"] profiler = ["alloc/profiler"] +geoblock = ["dep:geoblock"] [dependencies] alloc = { path = "./crates/alloc", optional = true } @@ -34,11 +36,15 @@ collections = { path = "./crates/collections", optional = true } future = { path = "./crates/future", optional = true } http = { path = "./crates/http", optional = true } metrics = { path = "./crates/metrics", optional = true } +geoblock = { path = "./crates/geoblock", optional = true } [dev-dependencies] anyhow = "1" structopt = { version = "0.3", default-features = false } tokio = { version = "1", features = ["full"] } +hyper = { version = "0.14", features = ["full"] } +tower = { version = "0.4", features = ["util", "filter"] } +axum = "0.6.1" [[example]] name = "alloc_profiler" @@ -51,3 +57,7 @@ required-features = ["alloc", "metrics"] [[example]] name = "metrics" required-features = ["metrics", "future"] + +[[example]] +name = "geoblock" +required-features = ["geoblock"] diff --git a/README.md b/README.md index e8cb0fe..1f0cc48 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,16 @@ Metrics and other utils for HTTP servers. Global service metrics. Currently based on `opentelemetry` SDK and exported in `prometheus` format. +## `geoblock` + +Tower middleware for blocking requests based on clients' IP origin. + +Note: this middleware requires you to use +[Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) +to run your app otherwise it will fail at runtime. + +See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. + ## Examples - [Metrics integration](examples/metrics.rs). Prints service metrics in the default (`prometheus`) format. diff --git a/crates/geoblock/Cargo.toml b/crates/geoblock/Cargo.toml new file mode 100644 index 0000000..6aee1cd --- /dev/null +++ b/crates/geoblock/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "geoblock" +version = "0.1.0" +edition = "2021" + +[features] +default = [] + +[dependencies] +tracing = "0.1" +axum = "0.6" +thiserror = "1.0" +hyper = "0.14" +tower = "0.4" +tower-layer = "0.3" +pin-project = "1" +futures-core = "0.3" +http-body = "0.4" +axum-client-ip = "0.4" +geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"} + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +axum = "0.6" diff --git a/crates/geoblock/src/lib.rs b/crates/geoblock/src/lib.rs new file mode 100644 index 0000000..f9caafa --- /dev/null +++ b/crates/geoblock/src/lib.rs @@ -0,0 +1,264 @@ +//! Middleware which adds geo-location IP blocking. +//! +//! Note: this middleware requires you to use +//! [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) +//! to run your app otherwise it will fail at runtime. +//! +//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. +pub use geoip; +use { + axum_client_ip::InsecureClientIp, + geoip::GeoIpResolver, + http_body::Body, + hyper::{Request, Response, StatusCode}, + pin_project::pin_project, + std::{ + future::Future, + net::IpAddr, + pin::Pin, + task::{Context, Poll}, + }, + thiserror::Error, + tower::Service, + tower_layer::Layer, + tracing::{error, info}, +}; + +#[cfg(test)] +mod tests; + +/// Values used to configure the middleware behavior when country information +/// could not be retrieved. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BlockingPolicy { + Block, + AllowMissingCountry, + AllowExtractFailure, + AllowAll, +} + +#[derive(Debug, Error)] +enum GeoBlockError { + #[error("Country is blocked")] + Blocked, + + #[error("Unable to extract IP address")] + UnableToExtractIPAddress, + + #[error("Unable to extract geo data from IP address")] + UnableToExtractGeoData, + + #[error("Country could not be found in database")] + CountryNotFound, +} + +/// Layer that applies the GeoBlock middleware which blocks requests base on IP +/// geo-location. +#[derive(Debug, Clone)] +#[must_use] +pub struct GeoBlockLayer +where + R: GeoIpResolver, +{ + blocked_countries: Vec, + ip_resolver: R, + blocking_policy: BlockingPolicy, +} + +impl GeoBlockLayer +where + R: GeoIpResolver, +{ + pub fn new( + ip_resolver: R, + blocked_countries: Vec, + blocking_policy: BlockingPolicy, + ) -> Self { + Self { + ip_resolver, + blocked_countries, + blocking_policy, + } + } +} + +impl Layer for GeoBlockLayer +where + R: GeoIpResolver, +{ + type Service = GeoBlockService; + + fn layer(&self, inner: S) -> Self::Service { + GeoBlockService::new( + inner, + self.ip_resolver.clone(), + self.blocked_countries.clone(), + self.blocking_policy, + ) + } +} + +/// Layer that applies the GeoBlock middleware which blocks requests base on IP +/// geo-location. +#[derive(Debug, Clone)] +#[must_use] +pub struct GeoBlockService +where + R: GeoIpResolver, +{ + inner: S, + blocked_countries: Vec, + ip_resolver: R, + blocking_policy: BlockingPolicy, +} + +impl GeoBlockService +where + R: GeoIpResolver, +{ + pub fn new( + inner: S, + ip_resolver: R, + blocked_countries: Vec, + blocking_policy: BlockingPolicy, + ) -> Self { + Self { + inner, + blocking_policy, + blocked_countries, + ip_resolver, + } + } + + /// Extracts the IP address from the request. + fn extract_ip(&self, req: &Request) -> Result { + let client_ip = InsecureClientIp::from(req.headers(), req.extensions()) + .map_err(|_| GeoBlockError::UnableToExtractIPAddress)?; + Ok(client_ip.0) + } + + /// Checks if the specified IP address is allowed to access the service. + fn check_ip(&self, caller: IpAddr) -> Result<(), GeoBlockError> { + let country = self + .ip_resolver + .lookup_geo_data(caller) + .map_err(|_| GeoBlockError::UnableToExtractGeoData)? + .country + .ok_or(GeoBlockError::CountryNotFound)?; + + if self + .blocked_countries + .iter() + .any(|blocked_country| *blocked_country == *country) + { + Err(GeoBlockError::Blocked) + } else { + Ok(()) + } + } + + fn check_caller(&self, req: &Request) -> Result<(), GeoBlockError> { + self.check_ip(self.extract_ip(req)?) + } +} + +impl Service> for GeoBlockService +where + S: Service, Response = Response>, + R: GeoIpResolver, + ResBody: Body + Default, +{ + type Error = S::Error; + type Future = ResponseFuture; + type Response = S::Response; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + match self.check_caller(&request) { + Ok(_) => ResponseFuture::future(self.inner.call(request)), + + Err(GeoBlockError::Blocked) => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + ResponseFuture::error(res) + } + + Err(GeoBlockError::UnableToExtractIPAddress) + | Err(GeoBlockError::UnableToExtractGeoData) => { + if self.blocking_policy == BlockingPolicy::AllowExtractFailure { + info!("Unable to extract client IP address"); + return ResponseFuture::future(self.inner.call(request)); + } + + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + ResponseFuture::error(res) + } + + Err(e) => { + if self.blocking_policy == BlockingPolicy::AllowMissingCountry { + error!("Unable to extract client IP address: {}", e); + return ResponseFuture::future(self.inner.call(request)); + } + + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + ResponseFuture::error(res) + } + } + } +} + +/// Response future for [`GeoBlock`]. +#[pin_project] +pub struct ResponseFuture { + #[pin] + inner: Kind, +} + +#[pin_project(project = KindProj)] +enum Kind { + Future { + #[pin] + future: F, + }, + Error { + response: Option>, + }, +} + +impl ResponseFuture { + fn future(future: F) -> Self { + Self { + inner: Kind::Future { future }, + } + } + + fn error(res: Response) -> Self { + Self { + inner: Kind::Error { + response: Some(res), + }, + } + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().inner.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Error { response } => { + let response = response.take().unwrap(); + Poll::Ready(Ok(response)) + } + } + } +} diff --git a/crates/geoblock/src/tests.rs b/crates/geoblock/src/tests.rs new file mode 100644 index 0000000..51d9a78 --- /dev/null +++ b/crates/geoblock/src/tests.rs @@ -0,0 +1,69 @@ +use { + crate::{geoip, BlockingPolicy, GeoBlockLayer}, + hyper::{Body, Request, Response, StatusCode}, + std::{ + convert::Infallible, + net::{IpAddr, Ipv4Addr}, + }, + tower::{Service, ServiceBuilder, ServiceExt}, +}; + +async fn handle(_request: Request) -> Result, Infallible> { + Ok(Response::new(Body::empty())) +} + +fn resolve_ip(caller: IpAddr) -> geoip::GeoData { + if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { + geoip::GeoData { + continent: Some("NA".to_string().into()), + country: Some("CU".to_string().into()), + region: None, + city: None, + } + } else { + geoip::GeoData { + continent: Some("NA".to_string().into()), + country: Some("US".to_string().into()), + region: None, + city: None, + } + } +} + +#[tokio::test] +async fn test_blocked_country() { + let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); + let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_non_blocked_country() { + let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); + let blocked_countries = vec!["IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} diff --git a/examples/geoblock.rs b/examples/geoblock.rs new file mode 100644 index 0000000..1b833b1 --- /dev/null +++ b/examples/geoblock.rs @@ -0,0 +1,53 @@ +use { + geoblock::{geoip::GeoData, BlockingPolicy, GeoBlockLayer}, + hyper::{Body, Request, Response, StatusCode}, + std::{ + convert::Infallible, + net::{IpAddr, Ipv4Addr}, + }, + tower::{Service, ServiceBuilder, ServiceExt}, + wc::geoblock::geoip::local::LocalResolver, +}; + +async fn handle(_request: Request) -> Result, Infallible> { + Ok(Response::new(Body::empty())) +} + +fn resolve_ip(caller: IpAddr) -> GeoData { + if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { + GeoData { + continent: Some("NA".to_string().into()), + country: Some("CU".to_string().into()), + region: None, + city: None, + } + } else { + GeoData { + continent: Some("NA".to_string().into()), + country: Some("US".to_string().into()), + region: None, + city: None, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let resolver: LocalResolver = LocalResolver::new(|caller| resolve_ip(caller)); + let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index f5fc637..dc3154d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,8 @@ pub use alloc; pub use collections; #[cfg(feature = "future")] pub use future; +#[cfg(feature = "geoblock")] +pub use geoblock; #[cfg(feature = "http")] pub use http; #[cfg(feature = "metrics")]