From 0564ba5a5d4f7bd4974d00b8fc57e7d9259eb685 Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Fri, 22 Sep 2023 16:07:38 +0200 Subject: [PATCH 1/7] feat: add `geoblock` tower middleware --- Cargo.toml | 10 ++ README.md | 4 + crates/geoblock/Cargo.toml | 20 ++++ crates/geoblock/src/errors.rs | 26 +++++ crates/geoblock/src/lib.rs | 193 ++++++++++++++++++++++++++++++++++ examples/geoblock.rs | 54 ++++++++++ src/lib.rs | 2 + 7 files changed, 309 insertions(+) create mode 100644 crates/geoblock/Cargo.toml create mode 100644 crates/geoblock/src/errors.rs create mode 100644 crates/geoblock/src/lib.rs create mode 100644 examples/geoblock.rs diff --git a/Cargo.toml b/Cargo.toml index bd95ad7..2ad6ccc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ full = [ "future", "http", "metrics", + "geoblock" ] alloc = ["dep:alloc"] collections = ["dep:collections"] @@ -27,6 +28,7 @@ future = ["dep:future"] http = [] metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"] profiler = ["alloc/profiler"] +geoblock = ["dep:geoblock"] [dependencies] alloc = { path = "./crates/alloc", optional = true } @@ -34,11 +36,15 @@ collections = { path = "./crates/collections", optional = true } future = { path = "./crates/future", optional = true } http = { path = "./crates/http", optional = true } metrics = { path = "./crates/metrics", optional = true } +geoblock = { path = "./crates/geoblock", optional = true } [dev-dependencies] anyhow = "1" structopt = { version = "0.3", default-features = false } tokio = { version = "1", features = ["full"] } +hyper = { version = "0.14", features = ["full"] } +tower = { version = "0.4", features = ["util", "filter"] } +axum = "0.6.1" [[example]] name = "alloc_profiler" @@ -51,3 +57,7 @@ required-features = ["alloc", "metrics"] [[example]] name = "metrics" required-features = ["metrics", "future"] + +[[example]] +name = "geoblock" +required-features = ["geoblock"] diff --git a/README.md b/README.md index e8cb0fe..a877602 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ Metrics and other utils for HTTP servers. Global service metrics. Currently based on `opentelemetry` SDK and exported in `prometheus` format. +## `geoblock` + +Tower middleware for blocking requests based on clients' IP origin. + ## Examples - [Metrics integration](examples/metrics.rs). Prints service metrics in the default (`prometheus`) format. diff --git a/crates/geoblock/Cargo.toml b/crates/geoblock/Cargo.toml new file mode 100644 index 0000000..28380b0 --- /dev/null +++ b/crates/geoblock/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "geoblock" +version = "0.1.0" +edition = "2021" + +[features] +default = [] +#full = ["metrics"] +#metrics = ["dep:metrics", "dep:future"] + +[dependencies] +axum = "0.6.1" +thiserror = "1.0" +hyper = "0.14" +tower = "0.4.13" +tower-layer = "0.3.2" +pin-project = "1.0.12" +futures-core = "0.3.28" +http-body = "0.4.5" +geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"} diff --git a/crates/geoblock/src/errors.rs b/crates/geoblock/src/errors.rs new file mode 100644 index 0000000..f67b5c4 --- /dev/null +++ b/crates/geoblock/src/errors.rs @@ -0,0 +1,26 @@ +use { + hyper::{HeaderMap, StatusCode}, + thiserror::Error, +}; + +#[derive(Debug, Error)] +pub enum GeoBlockError { + #[error("Country is blocked: {country}")] + BlockedCountry { country: String }, + + #[error("Unable to extract IP address")] + UnableToExtractIPAddress, + + #[error("Unable to extract geo data from IP address")] + UnableToExtractGeoData, + + #[error("Country could not be found in database")] + CountryNotFound, + + #[error("Other Error")] + Other { + code: StatusCode, + msg: Option, + headers: Option, + }, +} diff --git a/crates/geoblock/src/lib.rs b/crates/geoblock/src/lib.rs new file mode 100644 index 0000000..a954f2e --- /dev/null +++ b/crates/geoblock/src/lib.rs @@ -0,0 +1,193 @@ +// Middleware which adds geo-location IP blocking. + +pub use geoip; +use { + crate::errors::GeoBlockError, + axum::extract::ConnectInfo, + geoip::GeoIpResolver, + http_body::Body, + hyper::{Request, Response, StatusCode}, + pin_project::pin_project, + std::{ + future::Future, + net::{IpAddr, SocketAddr}, + pin::Pin, + task::{Context, Poll}, + }, + tower::{Layer, Service}, +}; + +pub mod errors; + +/// Layer that applies the GeoBlock middleware which blocks requests base on IP +/// geo-location. +#[derive(Debug, Clone)] +#[must_use] +pub struct GeoBlockLayer +where + T: GeoIpResolver, +{ + blocked_countries: Vec, + geoip: T, +} + +impl GeoBlockLayer +where + T: GeoIpResolver, +{ + pub fn new(geoip: T, blocked_countries: Vec) -> Self { + Self { + blocked_countries, + geoip, + } + } +} + +impl Layer for GeoBlockLayer +where + T: GeoIpResolver, +{ + type Service = GeoBlock; + + fn layer(&self, inner: S) -> Self::Service { + GeoBlock::new(inner, self.geoip.clone(), self.blocked_countries.clone()) + } +} + +#[derive(Clone, Debug)] +pub struct GeoBlock +where + R: GeoIpResolver, +{ + inner: S, + blocked_countries: Vec, + geoip: R, +} + +impl GeoBlock +where + R: GeoIpResolver, +{ + fn new(inner: S, geoip: R, blocked_countries: Vec) -> Self { + Self { + inner, + blocked_countries, + geoip, + } + } + + fn extract_ip(&self, req: &Request) -> Result { + req.extensions() + .get::>() + .map(|ConnectInfo(addr)| addr.ip()) + .ok_or(GeoBlockError::UnableToExtractIPAddress) + } + + fn check_caller(&self, caller: IpAddr) -> Result<(), GeoBlockError> { + let geo_data = self + .geoip + .lookup_geo_data(caller) + .map_err(|_| GeoBlockError::UnableToExtractGeoData)?; + + // TODO: let configure how to handle missing country + let country = geo_data + .country + .ok_or(GeoBlockError::CountryNotFound)? + .to_lowercase(); + + let is_blocked = self + .blocked_countries + .iter() + .any(|blocked_country| blocked_country == &country); + + if is_blocked { + Err(GeoBlockError::BlockedCountry { country }) + } else { + Ok(()) + } + } +} + +impl Service> for GeoBlock +where + R: GeoIpResolver, + S: Service, Response = Response>, + ResBody: Body + Default, +{ + type Error = S::Error; + type Future = ResponseFuture; + type Response = Response; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + match self.extract_ip(&req) { + Ok(ip_addr) => match self.check_caller(ip_addr) { + Ok(_) => ResponseFuture::future(self.inner.call(req)), + Err(_e) => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + ResponseFuture::invalid_ip(res) + } + }, + Err(_e) => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + ResponseFuture::invalid_ip(res) + } + } + } +} + +#[pin_project] +/// Response future for [`GeoBlock`]. +pub struct ResponseFuture { + #[pin] + inner: Kind, +} + +impl ResponseFuture { + fn future(future: F) -> Self { + Self { + inner: Kind::Future { future }, + } + } + + fn invalid_ip(res: Response) -> Self { + Self { + inner: Kind::Error { + response: Some(res), + }, + } + } +} + +#[pin_project(project = KindProj)] +enum Kind { + Future { + #[pin] + future: F, + }, + Error { + response: Option>, + }, +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().inner.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Error { response } => { + let response = response.take().unwrap(); + Poll::Ready(Ok(response)) + } + } + } +} diff --git a/examples/geoblock.rs b/examples/geoblock.rs new file mode 100644 index 0000000..5791301 --- /dev/null +++ b/examples/geoblock.rs @@ -0,0 +1,54 @@ +use { + geoblock::geoip::GeoData, + hyper::{Body, Request, Response, StatusCode}, + std::{ + convert::Infallible, + net::{IpAddr, Ipv4Addr}, + }, + tower::{Service, ServiceBuilder, ServiceExt}, + wc::geoblock::{geoip::local::LocalResolver, GeoBlockLayer}, +}; + +async fn handle(_request: Request) -> Result, Infallible> { + Ok(Response::new(Body::empty())) +} + +fn resolve_ip(caller: IpAddr) -> GeoData { + if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { + GeoData { + continent: Some("Asia".to_string().into()), + country: Some("Derkaderkastan".to_string().into()), + region: None, + city: None, + } + } else { + GeoData { + continent: Some("North America".to_string().into()), + country: Some("United States".to_string().into()), + region: None, + city: None, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let resolver: LocalResolver = LocalResolver::new(|caller| resolve_ip(caller)); + let blocked_countries = vec![ + "Derkaderkastan".to_string(), + "Quran".to_string(), + "Tristan".to_string(), + ]; + + let geoblock = GeoBlockLayer::new(resolver, blocked_countries); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = service.ready().await?.call(request).await?; + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index f5fc637..dc3154d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,8 @@ pub use alloc; pub use collections; #[cfg(feature = "future")] pub use future; +#[cfg(feature = "geoblock")] +pub use geoblock; #[cfg(feature = "http")] pub use http; #[cfg(feature = "metrics")] From 8f2a43d2960eb60eea18edfa225d0f13366e4163 Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Mon, 25 Sep 2023 16:39:55 +0200 Subject: [PATCH 2/7] fix: PR comments --- README.md | 6 ++ crates/geoblock/Cargo.toml | 11 ++- crates/geoblock/src/errors.rs | 26 ------- crates/geoblock/src/lib.rs | 134 +++++++++++++++++++++++++++++----- crates/geoblock/src/tests.rs | 67 +++++++++++++++++ examples/geoblock.rs | 10 +-- 6 files changed, 198 insertions(+), 56 deletions(-) delete mode 100644 crates/geoblock/src/errors.rs create mode 100644 crates/geoblock/src/tests.rs diff --git a/README.md b/README.md index a877602..1f0cc48 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ Global service metrics. Currently based on `opentelemetry` SDK and exported in ` Tower middleware for blocking requests based on clients' IP origin. +Note: this middleware requires you to use +[Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) +to run your app otherwise it will fail at runtime. + +See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. + ## Examples - [Metrics integration](examples/metrics.rs). Prints service metrics in the default (`prometheus`) format. diff --git a/crates/geoblock/Cargo.toml b/crates/geoblock/Cargo.toml index 28380b0..2ffaf43 100644 --- a/crates/geoblock/Cargo.toml +++ b/crates/geoblock/Cargo.toml @@ -5,11 +5,11 @@ edition = "2021" [features] default = [] -#full = ["metrics"] -#metrics = ["dep:metrics", "dep:future"] +full = ["tracing"] +tracing = ["dep:tracing"] [dependencies] -axum = "0.6.1" +axum = "0.6" thiserror = "1.0" hyper = "0.14" tower = "0.4.13" @@ -18,3 +18,8 @@ pin-project = "1.0.12" futures-core = "0.3.28" http-body = "0.4.5" geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"} +tracing = { version = "0.1", optional = true } + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +axum = "0.6.1" diff --git a/crates/geoblock/src/errors.rs b/crates/geoblock/src/errors.rs deleted file mode 100644 index f67b5c4..0000000 --- a/crates/geoblock/src/errors.rs +++ /dev/null @@ -1,26 +0,0 @@ -use { - hyper::{HeaderMap, StatusCode}, - thiserror::Error, -}; - -#[derive(Debug, Error)] -pub enum GeoBlockError { - #[error("Country is blocked: {country}")] - BlockedCountry { country: String }, - - #[error("Unable to extract IP address")] - UnableToExtractIPAddress, - - #[error("Unable to extract geo data from IP address")] - UnableToExtractGeoData, - - #[error("Country could not be found in database")] - CountryNotFound, - - #[error("Other Error")] - Other { - code: StatusCode, - msg: Option, - headers: Option, - }, -} diff --git a/crates/geoblock/src/lib.rs b/crates/geoblock/src/lib.rs index a954f2e..f517144 100644 --- a/crates/geoblock/src/lib.rs +++ b/crates/geoblock/src/lib.rs @@ -1,9 +1,15 @@ -// Middleware which adds geo-location IP blocking. - +/// Middleware which adds geo-location IP blocking. +/// +/// Note: this middleware requires you to use +/// [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) +/// to run your app otherwise it will fail at runtime. +/// +/// See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. pub use geoip; +#[cfg(feature = "tracing")] +use tracing::{error, info}; use { - crate::errors::GeoBlockError, - axum::extract::ConnectInfo, + axum::{extract::ConnectInfo, http::HeaderMap}, geoip::GeoIpResolver, http_body::Body, hyper::{Request, Response, StatusCode}, @@ -12,12 +18,45 @@ use { future::Future, net::{IpAddr, SocketAddr}, pin::Pin, + sync::Arc, task::{Context, Poll}, }, + thiserror::Error, tower::{Layer, Service}, }; -pub mod errors; +#[cfg(test)] +mod tests; + +/// Values used to configure the middleware behavior when country information +/// could not be retrieved. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum MissingCountry { + Allow, + Block, +} + +#[derive(Debug, Error)] +pub enum GeoBlockError { + #[error("Country is blocked: {country}")] + BlockedCountry { country: Arc }, + + #[error("Unable to extract IP address")] + UnableToExtractIPAddress, + + #[error("Unable to extract geo data from IP address")] + UnableToExtractGeoData, + + #[error("Country could not be found in database")] + CountryNotFound, + + #[error("Other Error")] + Other { + code: StatusCode, + msg: Option, + headers: Option, + }, +} /// Layer that applies the GeoBlock middleware which blocks requests base on IP /// geo-location. @@ -27,7 +66,8 @@ pub struct GeoBlockLayer where T: GeoIpResolver, { - blocked_countries: Vec, + missing_country: MissingCountry, + blocked_countries: Vec>, geoip: T, } @@ -35,8 +75,13 @@ impl GeoBlockLayer where T: GeoIpResolver, { - pub fn new(geoip: T, blocked_countries: Vec) -> Self { + pub fn new( + geoip: T, + blocked_countries: Vec>, + missing_country: MissingCountry, + ) -> Self { Self { + missing_country, blocked_countries, geoip, } @@ -50,7 +95,12 @@ where type Service = GeoBlock; fn layer(&self, inner: S) -> Self::Service { - GeoBlock::new(inner, self.geoip.clone(), self.blocked_countries.clone()) + GeoBlock::new( + inner, + self.geoip.clone(), + self.blocked_countries.clone(), + self.missing_country, + ) } } @@ -60,7 +110,8 @@ where R: GeoIpResolver, { inner: S, - blocked_countries: Vec, + missing_country: MissingCountry, + blocked_countries: Vec>, geoip: R, } @@ -68,9 +119,15 @@ impl GeoBlock where R: GeoIpResolver, { - fn new(inner: S, geoip: R, blocked_countries: Vec) -> Self { + fn new( + inner: S, + geoip: R, + blocked_countries: Vec>, + missing_country: MissingCountry, + ) -> Self { Self { inner, + missing_country, blocked_countries, geoip, } @@ -84,21 +141,45 @@ where } fn check_caller(&self, caller: IpAddr) -> Result<(), GeoBlockError> { - let geo_data = self + let country = match self .geoip .lookup_geo_data(caller) - .map_err(|_| GeoBlockError::UnableToExtractGeoData)?; - - // TODO: let configure how to handle missing country - let country = geo_data - .country - .ok_or(GeoBlockError::CountryNotFound)? - .to_lowercase(); + .map_err(|_| GeoBlockError::UnableToExtractGeoData) + { + Ok(geo_data) => match geo_data.country { + None if self.missing_country == MissingCountry::Allow => { + #[cfg(feature = "tracing")] + { + info!("Country not found, but allowed"); + } + return Ok(()); + } + None => { + #[cfg(feature = "tracing")] + { + info!("Country not found"); + } + return Err(GeoBlockError::CountryNotFound); + } + Some(country) => country, + }, + Err(_e) => { + return if self.missing_country == MissingCountry::Allow { + Ok(()) + } else { + #[cfg(feature = "tracing")] + { + error!("Unable to extract geo data from IP address: {}", _e); + } + Err(GeoBlockError::UnableToExtractGeoData) + } + } + }; let is_blocked = self .blocked_countries .iter() - .any(|blocked_country| blocked_country == &country); + .any(|blocked_country| *blocked_country == country); if is_blocked { Err(GeoBlockError::BlockedCountry { country }) @@ -126,13 +207,26 @@ where match self.extract_ip(&req) { Ok(ip_addr) => match self.check_caller(ip_addr) { Ok(_) => ResponseFuture::future(self.inner.call(req)), - Err(_e) => { + Err(GeoBlockError::BlockedCountry { country: _country }) => { let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::UNAUTHORIZED; ResponseFuture::invalid_ip(res) } + Err(_e) => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + ResponseFuture::invalid_ip(res) + } }, Err(_e) => { + if self.missing_country == MissingCountry::Allow { + #[cfg(feature = "tracing")] + { + error!("Unable to extract client IP address: {}", _e); + } + return ResponseFuture::future(self.inner.call(req)); + } + let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::UNAUTHORIZED; ResponseFuture::invalid_ip(res) diff --git a/crates/geoblock/src/tests.rs b/crates/geoblock/src/tests.rs new file mode 100644 index 0000000..09983bd --- /dev/null +++ b/crates/geoblock/src/tests.rs @@ -0,0 +1,67 @@ +use { + crate::geoip, + hyper::{Body, Request, Response, StatusCode}, + std::{ + convert::Infallible, + net::{IpAddr, Ipv4Addr}, + }, + tower::{Service, ServiceBuilder, ServiceExt}, +}; + +async fn handle(_request: Request) -> Result, Infallible> { + Ok(Response::new(Body::empty())) +} + +fn resolve_ip(caller: IpAddr) -> geoip::GeoData { + if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { + geoip::GeoData { + continent: Some("Asia".to_string().into()), + country: Some("Derkaderkastan".to_string().into()), + region: None, + city: None, + } + } else { + geoip::GeoData { + continent: Some("North America".to_string().into()), + country: Some("United States".to_string().into()), + region: None, + city: None, + } + } +} + +#[tokio::test] +async fn test_blocked_country() { + let resolver: geoip::local::LocalResolver = + geoip::local::LocalResolver::new(|caller| resolve_ip(caller)); + let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; + + let geoblock = + crate::GeoBlockLayer::new(resolver, blocked_countries, crate::MissingCountry::Allow); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_non_blocked_country() { + let resolver: geoip::local::LocalResolver = + geoip::local::LocalResolver::new(|caller| resolve_ip(caller)); + let blocked_countries = vec!["Quran".into(), "Tristan".into()]; + + let geoblock = + crate::GeoBlockLayer::new(resolver, blocked_countries, crate::MissingCountry::Allow); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} diff --git a/examples/geoblock.rs b/examples/geoblock.rs index 5791301..6a098c1 100644 --- a/examples/geoblock.rs +++ b/examples/geoblock.rs @@ -1,5 +1,5 @@ use { - geoblock::geoip::GeoData, + geoblock::{geoip::GeoData, MissingCountry}, hyper::{Body, Request, Response, StatusCode}, std::{ convert::Infallible, @@ -34,13 +34,9 @@ fn resolve_ip(caller: IpAddr) -> GeoData { #[tokio::main] async fn main() -> Result<(), Box> { let resolver: LocalResolver = LocalResolver::new(|caller| resolve_ip(caller)); - let blocked_countries = vec![ - "Derkaderkastan".to_string(), - "Quran".to_string(), - "Tristan".to_string(), - ]; + let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; - let geoblock = GeoBlockLayer::new(resolver, blocked_countries); + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, MissingCountry::Allow); let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); From de56d196c10fbd082ef7cff32de829db5a36e79f Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Tue, 26 Sep 2023 00:20:16 +0200 Subject: [PATCH 3/7] fix: simplify and fix PR comments --- crates/geoblock/Cargo.toml | 17 ++- crates/geoblock/src/lib.rs | 242 ++++++++++++++++------------------- crates/geoblock/src/tests.rs | 20 +-- examples/geoblock.rs | 13 +- 4 files changed, 138 insertions(+), 154 deletions(-) diff --git a/crates/geoblock/Cargo.toml b/crates/geoblock/Cargo.toml index 2ffaf43..6aee1cd 100644 --- a/crates/geoblock/Cargo.toml +++ b/crates/geoblock/Cargo.toml @@ -5,21 +5,20 @@ edition = "2021" [features] default = [] -full = ["tracing"] -tracing = ["dep:tracing"] [dependencies] +tracing = "0.1" axum = "0.6" thiserror = "1.0" hyper = "0.14" -tower = "0.4.13" -tower-layer = "0.3.2" -pin-project = "1.0.12" -futures-core = "0.3.28" -http-body = "0.4.5" +tower = "0.4" +tower-layer = "0.3" +pin-project = "1" +futures-core = "0.3" +http-body = "0.4" +axum-client-ip = "0.4" geoip = { git = "https://github.com/WalletConnect/geoip.git", tag = "0.2.0"} -tracing = { version = "0.1", optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } -axum = "0.6.1" +axum = "0.6" diff --git a/crates/geoblock/src/lib.rs b/crates/geoblock/src/lib.rs index f517144..b9f92a7 100644 --- a/crates/geoblock/src/lib.rs +++ b/crates/geoblock/src/lib.rs @@ -1,28 +1,28 @@ -/// Middleware which adds geo-location IP blocking. -/// -/// Note: this middleware requires you to use -/// [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) -/// to run your app otherwise it will fail at runtime. -/// -/// See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. +//! Middleware which adds geo-location IP blocking. +//! +//! Note: this middleware requires you to use +//! [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) +//! to run your app otherwise it will fail at runtime. +//! +//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. pub use geoip; -#[cfg(feature = "tracing")] -use tracing::{error, info}; use { - axum::{extract::ConnectInfo, http::HeaderMap}, + axum_client_ip::InsecureClientIp, geoip::GeoIpResolver, http_body::Body, hyper::{Request, Response, StatusCode}, pin_project::pin_project, std::{ future::Future, - net::{IpAddr, SocketAddr}, + net::IpAddr, pin::Pin, sync::Arc, task::{Context, Poll}, }, thiserror::Error, - tower::{Layer, Service}, + tower::Service, + tower_layer::Layer, + tracing::{error, info}, }; #[cfg(test)] @@ -31,15 +31,17 @@ mod tests; /// Values used to configure the middleware behavior when country information /// could not be retrieved. #[derive(Debug, Clone, Copy, PartialEq)] -pub enum MissingCountry { - Allow, +pub enum BlockingPolicy { Block, + AllowMissingCountry, + AllowExtractFailure, + AllowAll, } #[derive(Debug, Error)] -pub enum GeoBlockError { - #[error("Country is blocked: {country}")] - BlockedCountry { country: Arc }, +enum GeoBlockError { + #[error("Country is blocked")] + Blocked, #[error("Unable to extract IP address")] UnableToExtractIPAddress, @@ -49,199 +51,186 @@ pub enum GeoBlockError { #[error("Country could not be found in database")] CountryNotFound, - - #[error("Other Error")] - Other { - code: StatusCode, - msg: Option, - headers: Option, - }, } /// Layer that applies the GeoBlock middleware which blocks requests base on IP /// geo-location. #[derive(Debug, Clone)] #[must_use] -pub struct GeoBlockLayer +pub struct GeoBlockLayer where - T: GeoIpResolver, + R: GeoIpResolver, { - missing_country: MissingCountry, blocked_countries: Vec>, - geoip: T, + ip_resolver: R, + blocking_policy: BlockingPolicy, } -impl GeoBlockLayer +impl GeoBlockLayer where - T: GeoIpResolver, + R: GeoIpResolver, { pub fn new( - geoip: T, + ip_resolver: R, blocked_countries: Vec>, - missing_country: MissingCountry, + blocking_policy: BlockingPolicy, ) -> Self { Self { - missing_country, + ip_resolver, blocked_countries, - geoip, + blocking_policy, } } } -impl Layer for GeoBlockLayer +impl Layer for GeoBlockLayer where - T: GeoIpResolver, + R: GeoIpResolver, { - type Service = GeoBlock; + type Service = GeoBlockService; fn layer(&self, inner: S) -> Self::Service { - GeoBlock::new( + GeoBlockService::new( inner, - self.geoip.clone(), + self.ip_resolver.clone(), self.blocked_countries.clone(), - self.missing_country, + self.blocking_policy, ) } } -#[derive(Clone, Debug)] -pub struct GeoBlock +/// Layer that applies the GeoBlock middleware which blocks requests base on IP +/// geo-location. +#[derive(Debug, Clone)] +#[must_use] +pub struct GeoBlockService where R: GeoIpResolver, { inner: S, - missing_country: MissingCountry, blocked_countries: Vec>, - geoip: R, + ip_resolver: R, + blocking_policy: BlockingPolicy, } -impl GeoBlock +impl GeoBlockService where R: GeoIpResolver, { - fn new( + pub fn new( inner: S, - geoip: R, + ip_resolver: R, blocked_countries: Vec>, - missing_country: MissingCountry, + blocking_policy: BlockingPolicy, ) -> Self { Self { inner, - missing_country, + blocking_policy, blocked_countries, - geoip, + ip_resolver, } } + /// Extracts the IP address from the request. fn extract_ip(&self, req: &Request) -> Result { - req.extensions() - .get::>() - .map(|ConnectInfo(addr)| addr.ip()) - .ok_or(GeoBlockError::UnableToExtractIPAddress) + let client_ip = InsecureClientIp::from(req.headers(), req.extensions()) + .map_err(|_| GeoBlockError::UnableToExtractIPAddress)?; + Ok(client_ip.0) } - fn check_caller(&self, caller: IpAddr) -> Result<(), GeoBlockError> { - let country = match self - .geoip + /// Checks if the specified IP address is allowed to access the service. + fn check_ip(&self, caller: IpAddr) -> Result<(), GeoBlockError> { + let country = self + .ip_resolver .lookup_geo_data(caller) - .map_err(|_| GeoBlockError::UnableToExtractGeoData) - { - Ok(geo_data) => match geo_data.country { - None if self.missing_country == MissingCountry::Allow => { - #[cfg(feature = "tracing")] - { - info!("Country not found, but allowed"); - } - return Ok(()); - } - None => { - #[cfg(feature = "tracing")] - { - info!("Country not found"); - } - return Err(GeoBlockError::CountryNotFound); - } - Some(country) => country, - }, - Err(_e) => { - return if self.missing_country == MissingCountry::Allow { - Ok(()) - } else { - #[cfg(feature = "tracing")] - { - error!("Unable to extract geo data from IP address: {}", _e); - } - Err(GeoBlockError::UnableToExtractGeoData) - } - } - }; + .map_err(|_| GeoBlockError::UnableToExtractGeoData)? + .country + .ok_or(GeoBlockError::CountryNotFound)?; - let is_blocked = self + if self .blocked_countries .iter() - .any(|blocked_country| *blocked_country == country); - - if is_blocked { - Err(GeoBlockError::BlockedCountry { country }) + .any(|blocked_country| *blocked_country == country) + { + Err(GeoBlockError::Blocked) } else { Ok(()) } } + + fn check_caller(&self, req: &Request) -> Result<(), GeoBlockError> { + self.check_ip(self.extract_ip(req)?) + } } -impl Service> for GeoBlock +impl Service> for GeoBlockService where - R: GeoIpResolver, S: Service, Response = Response>, + R: GeoIpResolver, ResBody: Body + Default, { type Error = S::Error; type Future = ResponseFuture; - type Response = Response; + type Response = S::Response; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { - match self.extract_ip(&req) { - Ok(ip_addr) => match self.check_caller(ip_addr) { - Ok(_) => ResponseFuture::future(self.inner.call(req)), - Err(GeoBlockError::BlockedCountry { country: _country }) => { - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::UNAUTHORIZED; - ResponseFuture::invalid_ip(res) - } - Err(_e) => { - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - ResponseFuture::invalid_ip(res) + fn call(&mut self, request: Request) -> Self::Future { + match self.check_caller(&request) { + Ok(_) => ResponseFuture::future(self.inner.call(request)), + + Err(GeoBlockError::Blocked) => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + ResponseFuture::error(res) + } + + Err(GeoBlockError::UnableToExtractIPAddress) + | Err(GeoBlockError::UnableToExtractGeoData) => { + if self.blocking_policy == BlockingPolicy::AllowExtractFailure { + info!("Unable to extract client IP address"); + return ResponseFuture::future(self.inner.call(request)); } - }, - Err(_e) => { - if self.missing_country == MissingCountry::Allow { - #[cfg(feature = "tracing")] - { - error!("Unable to extract client IP address: {}", _e); - } - return ResponseFuture::future(self.inner.call(req)); + + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + ResponseFuture::error(res) + } + + Err(e) => { + if self.blocking_policy == BlockingPolicy::AllowMissingCountry { + error!("Unable to extract client IP address: {}", e); + return ResponseFuture::future(self.inner.call(request)); } let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::UNAUTHORIZED; - ResponseFuture::invalid_ip(res) + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + ResponseFuture::error(res) } } } } -#[pin_project] /// Response future for [`GeoBlock`]. +#[pin_project] pub struct ResponseFuture { #[pin] inner: Kind, } +#[pin_project(project = KindProj)] +enum Kind { + Future { + #[pin] + future: F, + }, + Error { + response: Option>, + }, +} + impl ResponseFuture { fn future(future: F) -> Self { Self { @@ -249,7 +238,7 @@ impl ResponseFuture { } } - fn invalid_ip(res: Response) -> Self { + fn error(res: Response) -> Self { Self { inner: Kind::Error { response: Some(res), @@ -258,17 +247,6 @@ impl ResponseFuture { } } -#[pin_project(project = KindProj)] -enum Kind { - Future { - #[pin] - future: F, - }, - Error { - response: Option>, - }, -} - impl Future for ResponseFuture where F: Future, E>>, diff --git a/crates/geoblock/src/tests.rs b/crates/geoblock/src/tests.rs index 09983bd..1127f58 100644 --- a/crates/geoblock/src/tests.rs +++ b/crates/geoblock/src/tests.rs @@ -1,5 +1,5 @@ use { - crate::geoip, + crate::{geoip, BlockingPolicy, GeoBlockLayer}, hyper::{Body, Request, Response, StatusCode}, std::{ convert::Infallible, @@ -36,12 +36,14 @@ async fn test_blocked_country() { geoip::local::LocalResolver::new(|caller| resolve_ip(caller)); let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; - let geoblock = - crate::GeoBlockLayer::new(resolver, blocked_countries, crate::MissingCountry::Allow); + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); - let request = Request::builder().body(Body::empty()).unwrap(); + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); let response = service.ready().await.unwrap().call(request).await.unwrap(); @@ -54,14 +56,16 @@ async fn test_non_blocked_country() { geoip::local::LocalResolver::new(|caller| resolve_ip(caller)); let blocked_countries = vec!["Quran".into(), "Tristan".into()]; - let geoblock = - crate::GeoBlockLayer::new(resolver, blocked_countries, crate::MissingCountry::Allow); + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); - let request = Request::builder().body(Body::empty()).unwrap(); + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(response.status(), StatusCode::OK); } diff --git a/examples/geoblock.rs b/examples/geoblock.rs index 6a098c1..9ee0d31 100644 --- a/examples/geoblock.rs +++ b/examples/geoblock.rs @@ -1,12 +1,12 @@ use { - geoblock::{geoip::GeoData, MissingCountry}, + geoblock::{geoip::GeoData, BlockingPolicy, GeoBlockLayer}, hyper::{Body, Request, Response, StatusCode}, std::{ convert::Infallible, net::{IpAddr, Ipv4Addr}, }, tower::{Service, ServiceBuilder, ServiceExt}, - wc::geoblock::{geoip::local::LocalResolver, GeoBlockLayer}, + wc::geoblock::geoip::local::LocalResolver, }; async fn handle(_request: Request) -> Result, Infallible> { @@ -36,13 +36,16 @@ async fn main() -> Result<(), Box> { let resolver: LocalResolver = LocalResolver::new(|caller| resolve_ip(caller)); let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; - let geoblock = GeoBlockLayer::new(resolver, blocked_countries, MissingCountry::Allow); + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); - let request = Request::builder().body(Body::empty()).unwrap(); + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); - let response = service.ready().await?.call(request).await?; + let response = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); From 15bd4dc0cc40dd9a22413f0b62e8042b59622bcb Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Tue, 26 Sep 2023 09:23:59 +0200 Subject: [PATCH 4/7] fix: clippy warnings --- crates/geoblock/src/tests.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/geoblock/src/tests.rs b/crates/geoblock/src/tests.rs index 1127f58..ed3e837 100644 --- a/crates/geoblock/src/tests.rs +++ b/crates/geoblock/src/tests.rs @@ -32,8 +32,7 @@ fn resolve_ip(caller: IpAddr) -> geoip::GeoData { #[tokio::test] async fn test_blocked_country() { - let resolver: geoip::local::LocalResolver = - geoip::local::LocalResolver::new(|caller| resolve_ip(caller)); + let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); @@ -52,8 +51,7 @@ async fn test_blocked_country() { #[tokio::test] async fn test_non_blocked_country() { - let resolver: geoip::local::LocalResolver = - geoip::local::LocalResolver::new(|caller| resolve_ip(caller)); + let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); let blocked_countries = vec!["Quran".into(), "Tristan".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); From 835dbfd43938692655728622f532de308f8bca3f Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Tue, 26 Sep 2023 12:57:54 +0200 Subject: [PATCH 5/7] fix: replace `Arc` with `String` for blocked countries --- crates/geoblock/src/lib.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/crates/geoblock/src/lib.rs b/crates/geoblock/src/lib.rs index b9f92a7..f9caafa 100644 --- a/crates/geoblock/src/lib.rs +++ b/crates/geoblock/src/lib.rs @@ -16,7 +16,6 @@ use { future::Future, net::IpAddr, pin::Pin, - sync::Arc, task::{Context, Poll}, }, thiserror::Error, @@ -61,7 +60,7 @@ pub struct GeoBlockLayer where R: GeoIpResolver, { - blocked_countries: Vec>, + blocked_countries: Vec, ip_resolver: R, blocking_policy: BlockingPolicy, } @@ -72,7 +71,7 @@ where { pub fn new( ip_resolver: R, - blocked_countries: Vec>, + blocked_countries: Vec, blocking_policy: BlockingPolicy, ) -> Self { Self { @@ -108,7 +107,7 @@ where R: GeoIpResolver, { inner: S, - blocked_countries: Vec>, + blocked_countries: Vec, ip_resolver: R, blocking_policy: BlockingPolicy, } @@ -120,7 +119,7 @@ where pub fn new( inner: S, ip_resolver: R, - blocked_countries: Vec>, + blocked_countries: Vec, blocking_policy: BlockingPolicy, ) -> Self { Self { @@ -150,7 +149,7 @@ where if self .blocked_countries .iter() - .any(|blocked_country| *blocked_country == country) + .any(|blocked_country| *blocked_country == *country) { Err(GeoBlockError::Blocked) } else { From 04108a70599c8f82412a0406046108997951c2a0 Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Tue, 26 Sep 2023 13:03:12 +0200 Subject: [PATCH 6/7] fix: use ISO codes in example --- examples/geoblock.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/geoblock.rs b/examples/geoblock.rs index 9ee0d31..1b833b1 100644 --- a/examples/geoblock.rs +++ b/examples/geoblock.rs @@ -16,15 +16,15 @@ async fn handle(_request: Request) -> Result, Infallible> { fn resolve_ip(caller: IpAddr) -> GeoData { if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { GeoData { - continent: Some("Asia".to_string().into()), - country: Some("Derkaderkastan".to_string().into()), + continent: Some("NA".to_string().into()), + country: Some("CU".to_string().into()), region: None, city: None, } } else { GeoData { - continent: Some("North America".to_string().into()), - country: Some("United States".to_string().into()), + continent: Some("NA".to_string().into()), + country: Some("US".to_string().into()), region: None, city: None, } @@ -34,7 +34,7 @@ fn resolve_ip(caller: IpAddr) -> GeoData { #[tokio::main] async fn main() -> Result<(), Box> { let resolver: LocalResolver = LocalResolver::new(|caller| resolve_ip(caller)); - let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; + let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); From 9bd0b9a9932978f5886c65cff13ec1f91bf51a7c Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Tue, 26 Sep 2023 13:14:46 +0200 Subject: [PATCH 7/7] fix: use ISO codes in tests --- crates/geoblock/src/tests.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/geoblock/src/tests.rs b/crates/geoblock/src/tests.rs index ed3e837..51d9a78 100644 --- a/crates/geoblock/src/tests.rs +++ b/crates/geoblock/src/tests.rs @@ -15,15 +15,15 @@ async fn handle(_request: Request) -> Result, Infallible> { fn resolve_ip(caller: IpAddr) -> geoip::GeoData { if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { geoip::GeoData { - continent: Some("Asia".to_string().into()), - country: Some("Derkaderkastan".to_string().into()), + continent: Some("NA".to_string().into()), + country: Some("CU".to_string().into()), region: None, city: None, } } else { geoip::GeoData { - continent: Some("North America".to_string().into()), - country: Some("United States".to_string().into()), + continent: Some("NA".to_string().into()), + country: Some("US".to_string().into()), region: None, city: None, } @@ -33,7 +33,7 @@ fn resolve_ip(caller: IpAddr) -> geoip::GeoData { #[tokio::test] async fn test_blocked_country() { let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); - let blocked_countries = vec!["Derkaderkastan".into(), "Quran".into(), "Tristan".into()]; + let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); @@ -52,7 +52,7 @@ async fn test_blocked_country() { #[tokio::test] async fn test_non_blocked_country() { let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); - let blocked_countries = vec!["Quran".into(), "Tristan".into()]; + let blocked_countries = vec!["IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);