Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable finer granularity for geo-blocking #10

Merged
merged 2 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 55 additions & 14 deletions crates/geoip/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockingPolicy: u8 {
const Block = 0b00;
const AllowMissingCountry = 0b01;
const AllowMissingGeoData = 0b01;
const AllowExtractFailure = 0b10;
const AllowAll = 0b11;
}
Expand All @@ -30,15 +30,34 @@ pub enum Error {
}

#[derive(Debug, Clone)]
pub struct CountryFilter {
blocked_countries: Vec<String>,
struct Zone {
country: String,
subdivisions: Vec<String>,
}

#[derive(Debug, Clone)]
pub struct ZoneFilter {
blocked_zones: Vec<Zone>,
blocking_policy: BlockingPolicy,
}

impl CountryFilter {
pub fn new(blocked_countries: Vec<String>, blocking_policy: BlockingPolicy) -> Self {
impl ZoneFilter {
pub fn new(blocked_zones: Vec<String>, blocking_policy: BlockingPolicy) -> Self {
let blocked_zones = blocked_zones
.iter()
.filter_map(|zone| {
zone.split(':')
.collect::<Vec<_>>()
.split_first()
.map(|(country, subdivisions)| Zone {
country: country.to_string(),
subdivisions: subdivisions.iter().map(|&s| s.to_string()).collect(),
})
})
.collect::<Vec<_>>();

Self {
blocked_countries,
blocked_zones,
blocking_policy,
}
}
Expand All @@ -49,19 +68,41 @@ impl CountryFilter {
where
R: Resolver,
{
let country = resolver
let geo_data = resolver
.lookup_geo_data_raw(addr)
.map_err(|_| Error::UnableToExtractGeoData)?
.map_err(|_| Error::UnableToExtractGeoData)?;

let country = geo_data
.country
.and_then(|country| country.iso_code)
.ok_or(Error::CountryNotFound)?;

let blocked = self
.blocked_countries
.iter()
.any(|blocked_country| blocked_country == country);
let zone_blocked = self.blocked_zones.iter().any(|blocked_zone| {
if blocked_zone.country == country {
if blocked_zone.subdivisions.is_empty() {
true
} else {
geo_data
.subdivisions
.as_deref()
.map_or(false, |subdivisions| {
subdivisions
.iter()
.filter_map(|sub| sub.iso_code)
.any(|sub| {
blocked_zone
.subdivisions
.iter()
.any(|blocked_sub| sub.eq_ignore_ascii_case(blocked_sub))
})
})
}
} else {
false
}
});

if blocked {
if zone_blocked {
Err(Error::Blocked)
} else {
Ok(())
Expand All @@ -75,7 +116,7 @@ impl CountryFilter {
let policy = self.blocking_policy;

let is_blocked = matches!(err, Error::UnableToExtractIPAddress | Error::UnableToExtractGeoData if !policy.contains(BlockingPolicy::AllowExtractFailure))
|| matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingCountry))
|| matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingGeoData))
|| matches!(err, Error::Blocked);

if is_blocked {
Expand Down
10 changes: 5 additions & 5 deletions crates/geoip/src/block/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details.

use {
super::{BlockingPolicy, CountryFilter, Error},
super::{BlockingPolicy, Error, ZoneFilter},
crate::Resolver,
axum_client_ip::InsecureClientIp,
futures::future::{self, Either, Ready},
Expand All @@ -26,7 +26,7 @@ mod tests;

#[derive(Debug)]
struct Inner<R> {
filter: CountryFilter,
filter: ZoneFilter,
ip_resolver: R,
}

Expand All @@ -52,7 +52,7 @@ where
) -> Self {
Self {
inner: Arc::new(Inner {
filter: CountryFilter::new(blocked_countries, blocking_policy),
filter: ZoneFilter::new(blocked_countries, blocking_policy),
ip_resolver,
}),
}
Expand Down Expand Up @@ -92,13 +92,13 @@ where
pub fn new(
service: S,
ip_resolver: R,
blocked_countries: Vec<String>,
blocked_zones: Vec<String>,
blocking_policy: BlockingPolicy,
) -> Self {
Self {
service,
inner: Arc::new(Inner {
filter: CountryFilter::new(blocked_countries, blocking_policy),
filter: ZoneFilter::new(blocked_zones, blocking_policy),
ip_resolver,
}),
}
Expand Down
161 changes: 151 additions & 10 deletions crates/geoip/src/block/middleware/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn handle(_request: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(Body::empty()))
}

fn resolve_ip(_addr: IpAddr) -> City<'static> {
fn resolve_ip_no_subs(_addr: IpAddr) -> City<'static> {
City {
city: None,
continent: None,
Expand All @@ -32,8 +32,40 @@ fn resolve_ip(_addr: IpAddr) -> City<'static> {
}
}

fn resolve_ip(_addr: IpAddr) -> City<'static> {
City {
city: None,
continent: None,
country: Some(geoip2::city::Country {
geoname_id: None,
is_in_european_union: None,
iso_code: Some("CU"),
names: None,
}),
location: None,
postal: None,
registered_country: None,
represented_country: None,
subdivisions: Some(vec![
geoip2::city::Subdivision {
geoname_id: None,
iso_code: Some("12"),
names: None,
},
geoip2::city::Subdivision {
geoname_id: None,
iso_code: Some("34"),
names: None,
},
]),
traits: None,
}
}

/// Test that a blocking list with no subdivisions blocks the country if
/// a match is found.
#[tokio::test]
async fn test_blocked_country() {
async fn test_country_blocked() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

Expand All @@ -51,10 +83,75 @@ async fn test_blocked_country() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that a blocking list with no subdivisions doesn't block if the
/// country doesn't match.
#[tokio::test]
async fn test_reference() {
async fn test_country_non_blocked() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];
let blocked_countries = vec!["IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

/// Test that a blocking list with subdivisions doesn't block if the
/// subdivisions don't match, even if the country matches.
#[tokio::test]
async fn test_sub_unblocked_wrong_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU:56".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

/// Test that a blocking list with subdivisions doesn't block if the country
/// doesn't match, even if subdivisions match.
#[tokio::test]
async fn test_sub_unblocked_wrong_country() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["IR:12".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

/// Test that a blocking list with subdivisions blocks containing only one
/// subdivision blocks if the country and subdivision match.
#[tokio::test]
async fn test_sub_blocked_country_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU:12".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

Expand All @@ -70,10 +167,12 @@ async fn test_reference() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that a blocking list with subdivisions blocks containing several
/// subdivisions blocks if the country and subdivision match.
#[tokio::test]
async fn test_arc() {
let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None));
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];
async fn test_subs_blocked_country_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU:12".into(), "CU:34".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

Expand All @@ -89,10 +188,33 @@ async fn test_arc() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that a blocking list with subdivisions blocks containing several
/// subdivisions in short form blocks if the country and subdivision match.
#[tokio::test]
async fn test_non_blocked_country() {
async fn test_short_subs_blocked_country_sub() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["IR".into(), "KP".into()];
let blocked_countries = vec!["CU:12:34".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

/// Test that the blocker doesn't crash if the GeoIP resolution doesn't contain
/// any subdivisions.
#[tokio::test]
async fn test_unresolved_subdivisions() {
let resolver = LocalResolver::new(Some(resolve_ip_no_subs), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);

Expand All @@ -105,5 +227,24 @@ async fn test_non_blocked_country() {

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_arc() {
let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None));
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
6 changes: 4 additions & 2 deletions examples/geoblock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> {
country: Some(geoip2::city::Country {
geoname_id: None,
is_in_european_union: None,
iso_code: Some("CU"),
iso_code: Some("IR"),
names: None,
}),
location: None,
Expand All @@ -35,7 +35,9 @@ fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let resolver: LocalResolver = LocalResolver::new(Some(|caller| resolve_ip(caller)), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];
// The number after the colon is the ISO code of the subdivision when you don't
// want to block the whole country.
let blocked_countries = vec!["CU:12".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block);

Expand Down
Loading