From cf8118f68a456a91c9549fa34225e37999b2aa43 Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Fri, 3 Nov 2023 09:28:44 +0100 Subject: [PATCH 1/2] chore: enable finer granularity for geo-blocking --- crates/geoip/src/block.rs | 55 +++++-- crates/geoip/src/block/middleware.rs | 10 +- crates/geoip/src/block/middleware/tests.rs | 161 +++++++++++++++++++-- examples/geoblock.rs | 6 +- 4 files changed, 202 insertions(+), 30 deletions(-) diff --git a/crates/geoip/src/block.rs b/crates/geoip/src/block.rs index bd0622b..ce11f31 100644 --- a/crates/geoip/src/block.rs +++ b/crates/geoip/src/block.rs @@ -8,7 +8,7 @@ bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct BlockingPolicy: u8 { const Block = 0b00; - const AllowMissingCountry = 0b01; + const AllowMissingGeoData = 0b01; const AllowExtractFailure = 0b10; const AllowAll = 0b11; } @@ -30,15 +30,15 @@ pub enum Error { } #[derive(Debug, Clone)] -pub struct CountryFilter { - blocked_countries: Vec, +pub struct ZoneFilter { + blocked_zones: Vec, blocking_policy: BlockingPolicy, } -impl CountryFilter { +impl ZoneFilter { pub fn new(blocked_countries: Vec, blocking_policy: BlockingPolicy) -> Self { Self { - blocked_countries, + blocked_zones: blocked_countries, blocking_policy, } } @@ -49,19 +49,48 @@ impl CountryFilter { where R: Resolver, { - let country = resolver + let geo_data = resolver .lookup_geo_data_raw(addr) - .map_err(|_| Error::UnableToExtractGeoData)? + .map_err(|_| Error::UnableToExtractGeoData)?; + + let country = geo_data .country .and_then(|country| country.iso_code) .ok_or(Error::CountryNotFound)?; - let blocked = self - .blocked_countries - .iter() - .any(|blocked_country| blocked_country == country); + let subdivisions: Vec<&str> = geo_data + .subdivisions + .map(|subdivisions| { + subdivisions + .into_iter() + .filter_map(|sub| sub.iso_code) + .collect::>() + }) + .unwrap_or_default(); + + let zone_blocked = self.blocked_zones.iter().any(|blocked_zone| { + blocked_zone + .split(':') + .collect::>() + .split_first() + .map_or(false, |(blocked_country, blocked_subdivisions)| { + if blocked_country == &country { + if blocked_subdivisions.is_empty() { + true + } else { + subdivisions.iter().any(|sub| { + blocked_subdivisions + .iter() + .any(|blocked_sub| sub.eq_ignore_ascii_case(blocked_sub)) + }) + } + } else { + false + } + }) + }); - if blocked { + if zone_blocked { Err(Error::Blocked) } else { Ok(()) @@ -75,7 +104,7 @@ impl CountryFilter { let policy = self.blocking_policy; let is_blocked = matches!(err, Error::UnableToExtractIPAddress | Error::UnableToExtractGeoData if !policy.contains(BlockingPolicy::AllowExtractFailure)) - || matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingCountry)) + || matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingGeoData)) || matches!(err, Error::Blocked); if is_blocked { diff --git a/crates/geoip/src/block/middleware.rs b/crates/geoip/src/block/middleware.rs index 0a82e95..9c1a9e6 100644 --- a/crates/geoip/src/block/middleware.rs +++ b/crates/geoip/src/block/middleware.rs @@ -7,7 +7,7 @@ //! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. use { - super::{BlockingPolicy, CountryFilter, Error}, + super::{BlockingPolicy, Error, ZoneFilter}, crate::Resolver, axum_client_ip::InsecureClientIp, futures::future::{self, Either, Ready}, @@ -26,7 +26,7 @@ mod tests; #[derive(Debug)] struct Inner { - filter: CountryFilter, + filter: ZoneFilter, ip_resolver: R, } @@ -52,7 +52,7 @@ where ) -> Self { Self { inner: Arc::new(Inner { - filter: CountryFilter::new(blocked_countries, blocking_policy), + filter: ZoneFilter::new(blocked_countries, blocking_policy), ip_resolver, }), } @@ -92,13 +92,13 @@ where pub fn new( service: S, ip_resolver: R, - blocked_countries: Vec, + blocked_zones: Vec, blocking_policy: BlockingPolicy, ) -> Self { Self { service, inner: Arc::new(Inner { - filter: CountryFilter::new(blocked_countries, blocking_policy), + filter: ZoneFilter::new(blocked_zones, blocking_policy), ip_resolver, }), } diff --git a/crates/geoip/src/block/middleware/tests.rs b/crates/geoip/src/block/middleware/tests.rs index 389e6b4..e0d9a9b 100644 --- a/crates/geoip/src/block/middleware/tests.rs +++ b/crates/geoip/src/block/middleware/tests.rs @@ -13,7 +13,7 @@ async fn handle(_request: Request) -> Result, Infallible> { Ok(Response::new(Body::empty())) } -fn resolve_ip(_addr: IpAddr) -> City<'static> { +fn resolve_ip_no_subs(_addr: IpAddr) -> City<'static> { City { city: None, continent: None, @@ -32,8 +32,40 @@ fn resolve_ip(_addr: IpAddr) -> City<'static> { } } +fn resolve_ip(_addr: IpAddr) -> City<'static> { + City { + city: None, + continent: None, + country: Some(geoip2::city::Country { + geoname_id: None, + is_in_european_union: None, + iso_code: Some("CU"), + names: None, + }), + location: None, + postal: None, + registered_country: None, + represented_country: None, + subdivisions: Some(vec![ + geoip2::city::Subdivision { + geoname_id: None, + iso_code: Some("12"), + names: None, + }, + geoip2::city::Subdivision { + geoname_id: None, + iso_code: Some("34"), + names: None, + }, + ]), + traits: None, + } +} + +/// Test that a blocking list with no subdivisions blocks the country if +/// a match is found. #[tokio::test] -async fn test_blocked_country() { +async fn test_country_blocked() { let resolver = LocalResolver::new(Some(resolve_ip), None); let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; @@ -51,10 +83,75 @@ async fn test_blocked_country() { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } +/// Test that a blocking list with no subdivisions doesn't block if the +/// country doesn't match. #[tokio::test] -async fn test_reference() { +async fn test_country_non_blocked() { let resolver = LocalResolver::new(Some(resolve_ip), None); - let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; + let blocked_countries = vec!["IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +/// Test that a blocking list with subdivisions doesn't block if the +/// subdivisions don't match, even if the country matches. +#[tokio::test] +async fn test_sub_unblocked_wrong_sub() { + let resolver = LocalResolver::new(Some(resolve_ip), None); + let blocked_countries = vec!["CU:56".into(), "IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +/// Test that a blocking list with subdivisions doesn't block if the country +/// doesn't match, even if subdivisions match. +#[tokio::test] +async fn test_sub_unblocked_wrong_country() { + let resolver = LocalResolver::new(Some(resolve_ip), None); + let blocked_countries = vec!["IR:12".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +/// Test that a blocking list with subdivisions blocks containing only one +/// subdivision blocks if the country and subdivision match. +#[tokio::test] +async fn test_sub_blocked_country_sub() { + let resolver = LocalResolver::new(Some(resolve_ip), None); + let blocked_countries = vec!["CU:12".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block); @@ -70,10 +167,12 @@ async fn test_reference() { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } +/// Test that a blocking list with subdivisions blocks containing several +/// subdivisions blocks if the country and subdivision match. #[tokio::test] -async fn test_arc() { - let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None)); - let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; +async fn test_subs_blocked_country_sub() { + let resolver = LocalResolver::new(Some(resolve_ip), None); + let blocked_countries = vec!["CU:12".into(), "CU:34".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block); @@ -89,10 +188,33 @@ async fn test_arc() { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } +/// Test that a blocking list with subdivisions blocks containing several +/// subdivisions in short form blocks if the country and subdivision match. #[tokio::test] -async fn test_non_blocked_country() { +async fn test_short_subs_blocked_country_sub() { let resolver = LocalResolver::new(Some(resolve_ip), None); - let blocked_countries = vec!["IR".into(), "KP".into()]; + let blocked_countries = vec!["CU:12:34".into(), "IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +/// Test that the blocker doesn't crash if the GeoIP resolution doesn't contain +/// any subdivisions. +#[tokio::test] +async fn test_unresolved_subdivisions() { + let resolver = LocalResolver::new(Some(resolve_ip_no_subs), None); + let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); @@ -105,5 +227,24 @@ async fn test_non_blocked_country() { let response = service.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_arc() { + let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None)); + let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; + + let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block); + + let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle); + + let request = Request::builder() + .header("X-Forwarded-For", "127.0.0.1") + .body(Body::empty()) + .unwrap(); + + let response = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } diff --git a/examples/geoblock.rs b/examples/geoblock.rs index 02ae781..167c87d 100644 --- a/examples/geoblock.rs +++ b/examples/geoblock.rs @@ -20,7 +20,7 @@ fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> { country: Some(geoip2::city::Country { geoname_id: None, is_in_european_union: None, - iso_code: Some("CU"), + iso_code: Some("IR"), names: None, }), location: None, @@ -35,7 +35,9 @@ fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> { #[tokio::main] async fn main() -> Result<(), Box> { let resolver: LocalResolver = LocalResolver::new(Some(|caller| resolve_ip(caller)), None); - let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; + // The number after the colon is the ISO code of the subdivision when you don't + // want to block the whole country. + let blocked_countries = vec!["CU:12".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); From 2d4d2a7762c5f681ddd7e76db4d67e4b183ef374 Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Fri, 3 Nov 2023 11:17:53 +0100 Subject: [PATCH 2/2] fix: preprocessing and optimize allocations --- crates/geoip/src/block.rs | 76 ++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/crates/geoip/src/block.rs b/crates/geoip/src/block.rs index ce11f31..63e7a8b 100644 --- a/crates/geoip/src/block.rs +++ b/crates/geoip/src/block.rs @@ -29,16 +29,35 @@ pub enum Error { CountryNotFound, } +#[derive(Debug, Clone)] +struct Zone { + country: String, + subdivisions: Vec, +} + #[derive(Debug, Clone)] pub struct ZoneFilter { - blocked_zones: Vec, + blocked_zones: Vec, blocking_policy: BlockingPolicy, } impl ZoneFilter { - pub fn new(blocked_countries: Vec, blocking_policy: BlockingPolicy) -> Self { + pub fn new(blocked_zones: Vec, blocking_policy: BlockingPolicy) -> Self { + let blocked_zones = blocked_zones + .iter() + .filter_map(|zone| { + zone.split(':') + .collect::>() + .split_first() + .map(|(country, subdivisions)| Zone { + country: country.to_string(), + subdivisions: subdivisions.iter().map(|&s| s.to_string()).collect(), + }) + }) + .collect::>(); + Self { - blocked_zones: blocked_countries, + blocked_zones, blocking_policy, } } @@ -58,36 +77,29 @@ impl ZoneFilter { .and_then(|country| country.iso_code) .ok_or(Error::CountryNotFound)?; - let subdivisions: Vec<&str> = geo_data - .subdivisions - .map(|subdivisions| { - subdivisions - .into_iter() - .filter_map(|sub| sub.iso_code) - .collect::>() - }) - .unwrap_or_default(); - let zone_blocked = self.blocked_zones.iter().any(|blocked_zone| { - blocked_zone - .split(':') - .collect::>() - .split_first() - .map_or(false, |(blocked_country, blocked_subdivisions)| { - if blocked_country == &country { - if blocked_subdivisions.is_empty() { - true - } else { - subdivisions.iter().any(|sub| { - blocked_subdivisions - .iter() - .any(|blocked_sub| sub.eq_ignore_ascii_case(blocked_sub)) - }) - } - } else { - false - } - }) + if blocked_zone.country == country { + if blocked_zone.subdivisions.is_empty() { + true + } else { + geo_data + .subdivisions + .as_deref() + .map_or(false, |subdivisions| { + subdivisions + .iter() + .filter_map(|sub| sub.iso_code) + .any(|sub| { + blocked_zone + .subdivisions + .iter() + .any(|blocked_sub| sub.eq_ignore_ascii_case(blocked_sub)) + }) + }) + } + } else { + false + } }); if zone_blocked {