Skip to content

Commit

Permalink
fix: add Resolver implementations for references and Arc<> (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
xav authored Sep 29, 2023
1 parent c5d0d06 commit 356602a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
40 changes: 39 additions & 1 deletion crates/geoip/src/block/middleware/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use {
},
hyper::{Body, Request, Response, StatusCode},
maxminddb::{geoip2, geoip2::City},
std::{convert::Infallible, net::IpAddr},
std::{convert::Infallible, net::IpAddr, sync::Arc},
tower::{Service, ServiceBuilder, ServiceExt},
};

Expand Down Expand Up @@ -51,6 +51,44 @@ async fn test_blocked_country() {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_reference() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_arc() {
let resolver = Arc::from(LocalResolver::new(Some(resolve_ip), None));
let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()];

let geoblock = GeoBlockLayer::new(&resolver, blocked_countries, BlockingPolicy::Block);

let mut service = ServiceBuilder::new().layer(geoblock).service_fn(handle);

let request = Request::builder()
.header("X-Forwarded-For", "127.0.0.1")
.body(Body::empty())
.unwrap();

let response = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn test_non_blocked_country() {
let resolver = LocalResolver::new(Some(resolve_ip), None);
Expand Down
36 changes: 35 additions & 1 deletion crates/geoip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use {
},
bytes::Bytes,
maxminddb::geoip2::City,
std::{net::IpAddr, sync::Arc},
std::{net::IpAddr, ops::Deref, sync::Arc},
};

pub mod block;
Expand All @@ -32,6 +32,40 @@ pub trait Resolver: Clone {
fn lookup_geo_data(&self, addr: IpAddr) -> Result<Data, Self::Error>;
}

impl<'a, T> Resolver for &'a T
where
T: Resolver,
{
type Error = T::Error;

fn lookup_geo_data_raw(&self, addr: IpAddr) -> Result<City<'_>, Self::Error> {
let r = <&T>::deref(self);
r.lookup_geo_data_raw(addr)
}

fn lookup_geo_data(&self, addr: IpAddr) -> Result<Data, Self::Error> {
let r = <&T>::deref(self);
r.lookup_geo_data(addr)
}
}

impl<T> Resolver for Arc<T>
where
T: Resolver,
{
type Error = T::Error;

fn lookup_geo_data_raw(&self, addr: IpAddr) -> Result<City<'_>, Self::Error> {
let r = self.deref();
r.lookup_geo_data_raw(addr)
}

fn lookup_geo_data(&self, addr: IpAddr) -> Result<Data, Self::Error> {
let r = self.deref();
r.lookup_geo_data(addr)
}
}

#[derive(Debug, thiserror::Error)]
pub enum LocalResolverError {
#[error("Geoip data lookup is not supported")]
Expand Down

0 comments on commit 356602a

Please sign in to comment.