Skip to content

Commit

Permalink
feat: enable finer granularity for geo-blocking (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
xav authored Nov 3, 2023
1 parent 3f6fcca commit 5f636b8
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 31 deletions.
69 changes: 55 additions & 14 deletions crates/geoip/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockingPolicy: u8 {
const Block = 0b00;
const AllowMissingCountry = 0b01;
const AllowMissingGeoData = 0b01;
const AllowExtractFailure = 0b10;
const AllowAll = 0b11;
}
Expand All @@ -30,15 +30,34 @@ pub enum Error {
}

#[derive(Debug, Clone)]
pub struct CountryFilter {
blocked_countries: Vec<String>,
struct Zone {
country: String,
subdivisions: Vec<String>,
}

#[derive(Debug, Clone)]
pub struct ZoneFilter {
blocked_zones: Vec<Zone>,
blocking_policy: BlockingPolicy,
}

impl CountryFilter {
pub fn new(blocked_countries: Vec<String>, blocking_policy: BlockingPolicy) -> Self {
impl ZoneFilter {
pub fn new(blocked_zones: Vec<String>, blocking_policy: BlockingPolicy) -> Self {
let blocked_zones = blocked_zones
.iter()
.filter_map(|zone| {
zone.split(':')
.collect::<Vec<_>>()
.split_first()
.map(|(country, subdivisions)| Zone {
country: country.to_string(),
subdivisions: subdivisions.iter().map(|&s| s.to_string()).collect(),
})
})
.collect::<Vec<_>>();

Self {
blocked_countries,
blocked_zones,
blocking_policy,
}
}
Expand All @@ -49,19 +68,41 @@ impl CountryFilter {
where
R: Resolver,
{
let country = resolver
let geo_data = resolver
.lookup_geo_data_raw(addr)
.map_err(|_| Error::UnableToExtractGeoData)?
.map_err(|_| Error::UnableToExtractGeoData)?;

let country = geo_data
.country
.and_then(|country| country.iso_code)
.ok_or(Error::CountryNotFound)?;

let blocked = self
.blocked_countries
.iter()
.any(|blocked_country| blocked_country == country);
let zone_blocked = self.blocked_zones.iter().any(|blocked_zone| {
if blocked_zone.country == country {
if blocked_zone.subdivisions.is_empty() {
true
} else {
geo_data
.subdivisions
.as_deref()
.map_or(false, |subdivisions| {
subdivisions
.iter()
.filter_map(|sub| sub.iso_code)
.any(|sub| {
blocked_zone
.subdivisions
.iter()
.any(|blocked_sub| sub.eq_ignore_ascii_case(blocked_sub))
})
})
}
} else {
false
}
});

if blocked {
if zone_blocked {
Err(Error::Blocked)
} else {
Ok(())
Expand All @@ -75,7 +116,7 @@ impl CountryFilter {
let policy = self.blocking_policy;

let is_blocked = matches!(err, Error::UnableToExtractIPAddress | Error::UnableToExtractGeoData if !policy.contains(BlockingPolicy::AllowExtractFailure))
|| matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingCountry))
|| matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingGeoData))
|| matches!(err, Error::Blocked);

if is_blocked {
Expand Down
10 changes: 5 additions & 5 deletions crates/geoip/src/block/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.
use {
super::{BlockingPolicy, CountryFilter, Error},
super::{BlockingPolicy, Error, ZoneFilter},
crate::Resolver,
axum_client_ip::InsecureClientIp,
futures::future::{self, Either, Ready},
Expand All @@ -26,7 +26,7 @@ mod tests;

#[derive(Debug)]
struct Inner<R> {
filter: CountryFilter,
filter: ZoneFilter,
ip_resolver: R,
}

Expand All @@ -52,7 +52,7 @@ where
) -> Self {
Self {
inner: Arc::new(Inner {
filter: CountryFilter::new(blocked_countries, blocking_policy),
filter: ZoneFilter::new(blocked_countries, blocking_policy),
ip_resolver,
}),
}
Expand Down Expand Up @@ -92,13 +92,13 @@ where
pub fn new(
service: S,
ip_resolver: R,
blocked_countries: Vec<String>,
blocked_zones: Vec<String>,
blocking_policy: BlockingPolicy,
) -> Self {
Self {
service,
inner: Arc::new(Inner {
filter: CountryFilter::new(blocked_countries, blocking_policy),
filter: ZoneFilter::new(blocked_zones, blocking_policy),
ip_resolver,
}),
}
Expand Down
161 changes: 151 additions & 10 deletions crates/geoip/src/block/middleware/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn handle(_request: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::empty()))
}

fn resolve_ip(_addr: IpAddr) -> City<'static> {
fn resolve_ip_no_subs(_addr: IpAddr) -> City<'static> {
City {
city: None,
continent: None,
Expand All @@ -32,8 +32,40 @@ fn resolve_ip(_addr: IpAddr) -> City<'static> {
}
}

fn resolve_ip(_addr: IpAddr) -> City<'static> {
City {
city: None,
continent: None,
country: Some(geoip2::city::Country {
geoname_id: None,
is_in_european_union: None,
iso_code: Some("CU"),
names: None,
}),
location: None,
postal: None,
registered_country: None,
represented_country: None,
subdivisions: Some(vec![
geoip2::city::Subdivision {
geoname_id: None,
iso_code: Some("12"),
names: None,
},
geoip2::city::Subdivision {
geoname_id: None,
iso_code: Some("34"),
names: None,
},
]),
traits: None,
}
}

/// Test that a blocking list with no subdivisions blocks the country if
/// a match is found.
#[tokio::test]
async fn test_blocked_country() {
async fn test_country_blocked() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

Expand All @@ -51,10 +83,75 @@ async fn test_blocked_country() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that a blocking list with no subdivisions doesn't block if the
/// country doesn't match.
#[tokio::test]
async fn test_reference() {
async fn test_country_non_blocked() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];
let blocked_countries = vec!["IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

/// Test that a blocking list with subdivisions doesn't block if the
/// subdivisions don't match, even if the country matches.
#[tokio::test]
async fn test_sub_unblocked_wrong_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU:56".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

/// Test that a blocking list with subdivisions doesn't block if the country
/// doesn't match, even if subdivisions match.
#[tokio::test]
async fn test_sub_unblocked_wrong_country() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["IR:12".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

/// Test that a blocking list with subdivisions blocks containing only one
/// subdivision blocks if the country and subdivision match.
#[tokio::test]
async fn test_sub_blocked_country_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU:12".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

Expand All @@ -70,10 +167,12 @@ async fn test_reference() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that a blocking list with subdivisions blocks containing several
/// subdivisions blocks if the country and subdivision match.
#[tokio::test]
async fn test_arc() {
let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None));
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];
async fn test_subs_blocked_country_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU:12".into(), "CU:34".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

Expand All @@ -89,10 +188,33 @@ async fn test_arc() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that a blocking list with subdivisions blocks containing several
/// subdivisions in short form blocks if the country and subdivision match.
#[tokio::test]
async fn test_non_blocked_country() {
async fn test_short_subs_blocked_country_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["IR".into(), "KP".into()];
let blocked_countries = vec!["CU:12:34".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that the blocker doesn't crash if the GeoIP resolution doesn't contain
/// any subdivisions.
#[tokio::test]
async fn test_unresolved_subdivisions() {
let resolver = LocalResolver::new(Some(resolve_ip_no_subs), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);

Expand All @@ -105,5 +227,24 @@ async fn test_non_blocked_country() {

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_arc() {
let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None));
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
6 changes: 4 additions & 2 deletions examples/geoblock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> {
country: Some(geoip2::city::Country {
geoname_id: None,
is_in_european_union: None,
iso_code: Some("CU"),
iso_code: Some("IR"),
names: None,
}),
location: None,
Expand All @@ -35,7 +35,9 @@ fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let resolver: LocalResolver = LocalResolver::new(Some(|caller| resolve_ip(caller)), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];
// The number after the colon is the ISO code of the subdivision when you don't
// want to block the whole country.
let blocked_countries = vec!["CU:12".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);

Expand Down

0 comments on commit 5f636b8

Please sign in to comment.