From c5d0d0670e68cf15c6dcd411ffd3337fa0ba1d98 Mon Sep 17 00:00:00 2001 From: Ivan Reshetnikov Date: Fri, 29 Sep 2023 11:47:56 +0200 Subject: [PATCH] feat: geoip improvements (#7) --- Cargo.toml | 3 +- crates/analytics/Cargo.toml | 4 +- crates/geoblock/Cargo.toml | 24 -- crates/geoblock/src/lib.rs | 264 ------------------ crates/geoip/Cargo.toml | 21 +- crates/geoip/src/block.rs | 90 ++++++ crates/geoip/src/block/middleware.rs | 151 ++++++++++ .../src/block/middleware}/tests.rs | 46 +-- crates/geoip/src/lib.rs | 139 ++++++++- crates/geoip/src/local.rs | 24 -- crates/geoip/src/maxmind.rs | 82 ------ examples/geoblock.rs | 45 +-- src/lib.rs | 2 - 13 files changed, 442 insertions(+), 453 deletions(-) delete mode 100644 crates/geoblock/Cargo.toml delete mode 100644 crates/geoblock/src/lib.rs create mode 100644 crates/geoip/src/block.rs create mode 100644 crates/geoip/src/block/middleware.rs rename crates/{geoblock/src => geoip/src/block/middleware}/tests.rs (62%) delete mode 100644 crates/geoip/src/local.rs delete mode 100644 crates/geoip/src/maxmind.rs diff --git a/Cargo.toml b/Cargo.toml index 446e054..ff5928b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ alloc = ["dep:alloc"] analytics = ["dep:analytics"] collections = ["dep:collections"] future = ["dep:future"] -geoblock = ["dep:geoblock", "dep:geoip"] +geoblock = ["geoip/middleware"] geoip = ["dep:geoip"] http = [] metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"] @@ -39,7 +39,6 @@ alloc = { path = "./crates/alloc", optional = true } analytics = { path = "./crates/analytics", optional = true } collections = { path = "./crates/collections", optional = true } future = { path = "./crates/future", optional = true } -geoblock = { path = "./crates/geoblock", optional = true } geoip = { path = "./crates/geoip", optional = true } http = { path = "./crates/http", optional = true } metrics = { path = "./crates/metrics", optional = true } diff --git a/crates/analytics/Cargo.toml b/crates/analytics/Cargo.toml index f7505fa..f4b35f7 100644 --- a/crates/analytics/Cargo.toml +++ b/crates/analytics/Cargo.toml @@ -14,7 +14,7 @@ anyhow = "1" tap = "1.0" chrono = { version = "0.4" } -aws-sdk-s3 = "0.25" -bytes = "1.2" +aws-sdk-s3 = "0.31" +bytes = "1.5" parquet = { git = "https://github.com/WalletConnect/arrow-rs.git", rev = "99a1cc3", default-features = false, features = ["flate2"] } parquet_derive = { git = "https://github.com/WalletConnect/arrow-rs.git", rev = "99a1cc3" } diff --git a/crates/geoblock/Cargo.toml b/crates/geoblock/Cargo.toml deleted file mode 100644 index de024ae..0000000 --- a/crates/geoblock/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "geoblock" -version = "0.1.0" -edition = "2021" - -[features] -default = [] - -[dependencies] -tracing = "0.1" -axum = "0.6" -thiserror = "1.0" -hyper = "0.14" -tower = "0.4" -tower-layer = "0.3" -pin-project = "1" -futures-core = "0.3" -http-body = "0.4" -axum-client-ip = "0.4" -geoip = { path = "../geoip" } - -[dev-dependencies] -tokio = { version = "1", features = ["full"] } -axum = "0.6" diff --git a/crates/geoblock/src/lib.rs b/crates/geoblock/src/lib.rs deleted file mode 100644 index f9caafa..0000000 --- a/crates/geoblock/src/lib.rs +++ /dev/null @@ -1,264 +0,0 @@ -//! Middleware which adds geo-location IP blocking. -//! -//! Note: this middleware requires you to use -//! [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) -//! to run your app otherwise it will fail at runtime. -//! -//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. -pub use geoip; -use { - axum_client_ip::InsecureClientIp, - geoip::GeoIpResolver, - http_body::Body, - hyper::{Request, Response, StatusCode}, - pin_project::pin_project, - std::{ - future::Future, - net::IpAddr, - pin::Pin, - task::{Context, Poll}, - }, - thiserror::Error, - tower::Service, - tower_layer::Layer, - tracing::{error, info}, -}; - -#[cfg(test)] -mod tests; - -/// Values used to configure the middleware behavior when country information -/// could not be retrieved. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum BlockingPolicy { - Block, - AllowMissingCountry, - AllowExtractFailure, - AllowAll, -} - -#[derive(Debug, Error)] -enum GeoBlockError { - #[error("Country is blocked")] - Blocked, - - #[error("Unable to extract IP address")] - UnableToExtractIPAddress, - - #[error("Unable to extract geo data from IP address")] - UnableToExtractGeoData, - - #[error("Country could not be found in database")] - CountryNotFound, -} - -/// Layer that applies the GeoBlock middleware which blocks requests base on IP -/// geo-location. -#[derive(Debug, Clone)] -#[must_use] -pub struct GeoBlockLayer -where - R: GeoIpResolver, -{ - blocked_countries: Vec, - ip_resolver: R, - blocking_policy: BlockingPolicy, -} - -impl GeoBlockLayer -where - R: GeoIpResolver, -{ - pub fn new( - ip_resolver: R, - blocked_countries: Vec, - blocking_policy: BlockingPolicy, - ) -> Self { - Self { - ip_resolver, - blocked_countries, - blocking_policy, - } - } -} - -impl Layer for GeoBlockLayer -where - R: GeoIpResolver, -{ - type Service = GeoBlockService; - - fn layer(&self, inner: S) -> Self::Service { - GeoBlockService::new( - inner, - self.ip_resolver.clone(), - self.blocked_countries.clone(), - self.blocking_policy, - ) - } -} - -/// Layer that applies the GeoBlock middleware which blocks requests base on IP -/// geo-location. -#[derive(Debug, Clone)] -#[must_use] -pub struct GeoBlockService -where - R: GeoIpResolver, -{ - inner: S, - blocked_countries: Vec, - ip_resolver: R, - blocking_policy: BlockingPolicy, -} - -impl GeoBlockService -where - R: GeoIpResolver, -{ - pub fn new( - inner: S, - ip_resolver: R, - blocked_countries: Vec, - blocking_policy: BlockingPolicy, - ) -> Self { - Self { - inner, - blocking_policy, - blocked_countries, - ip_resolver, - } - } - - /// Extracts the IP address from the request. - fn extract_ip(&self, req: &Request) -> Result { - let client_ip = InsecureClientIp::from(req.headers(), req.extensions()) - .map_err(|_| GeoBlockError::UnableToExtractIPAddress)?; - Ok(client_ip.0) - } - - /// Checks if the specified IP address is allowed to access the service. - fn check_ip(&self, caller: IpAddr) -> Result<(), GeoBlockError> { - let country = self - .ip_resolver - .lookup_geo_data(caller) - .map_err(|_| GeoBlockError::UnableToExtractGeoData)? - .country - .ok_or(GeoBlockError::CountryNotFound)?; - - if self - .blocked_countries - .iter() - .any(|blocked_country| *blocked_country == *country) - { - Err(GeoBlockError::Blocked) - } else { - Ok(()) - } - } - - fn check_caller(&self, req: &Request) -> Result<(), GeoBlockError> { - self.check_ip(self.extract_ip(req)?) - } -} - -impl Service> for GeoBlockService -where - S: Service, Response = Response>, - R: GeoIpResolver, - ResBody: Body + Default, -{ - type Error = S::Error; - type Future = ResponseFuture; - type Response = S::Response; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, request: Request) -> Self::Future { - match self.check_caller(&request) { - Ok(_) => ResponseFuture::future(self.inner.call(request)), - - Err(GeoBlockError::Blocked) => { - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::UNAUTHORIZED; - ResponseFuture::error(res) - } - - Err(GeoBlockError::UnableToExtractIPAddress) - | Err(GeoBlockError::UnableToExtractGeoData) => { - if self.blocking_policy == BlockingPolicy::AllowExtractFailure { - info!("Unable to extract client IP address"); - return ResponseFuture::future(self.inner.call(request)); - } - - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - ResponseFuture::error(res) - } - - Err(e) => { - if self.blocking_policy == BlockingPolicy::AllowMissingCountry { - error!("Unable to extract client IP address: {}", e); - return ResponseFuture::future(self.inner.call(request)); - } - - let mut res = Response::new(ResBody::default()); - *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - ResponseFuture::error(res) - } - } - } -} - -/// Response future for [`GeoBlock`]. -#[pin_project] -pub struct ResponseFuture { - #[pin] - inner: Kind, -} - -#[pin_project(project = KindProj)] -enum Kind { - Future { - #[pin] - future: F, - }, - Error { - response: Option>, - }, -} - -impl ResponseFuture { - fn future(future: F) -> Self { - Self { - inner: Kind::Future { future }, - } - } - - fn error(res: Response) -> Self { - Self { - inner: Kind::Error { - response: Some(res), - }, - } - } -} - -impl Future for ResponseFuture -where - F: Future, E>>, -{ - type Output = F::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project().inner.project() { - KindProj::Future { future } => future.poll(cx), - KindProj::Error { response } => { - let response = response.take().unwrap(); - Poll::Ready(Ok(response)) - } - } - } -} diff --git a/crates/geoip/Cargo.toml b/crates/geoip/Cargo.toml index c7a961d..b8f7748 100644 --- a/crates/geoip/Cargo.toml +++ b/crates/geoip/Cargo.toml @@ -3,10 +3,25 @@ name = "geoip" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = [] +full = ["middleware"] +middleware = ["dep:tower", "dep:tower-layer", "dep:axum-client-ip", "dep:http-body"] [dependencies] +tower = { version = "0.4", optional = true } +tower-layer = { version = "0.3", optional = true } +http-body = { version = "0.4", optional = true } +axum-client-ip = { version = "0.4", optional = true } +bitflags = "2.4" +hyper = "0.14" +tracing = "0.1" +thiserror = "1.0" +futures = "0.3" bytes = "1.5" -aws-sdk-s3 = "0.31.0" +aws-sdk-s3 = "0.31" maxminddb = "0.23" -thiserror = "1.0" + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +axum = "0.6" diff --git a/crates/geoip/src/block.rs b/crates/geoip/src/block.rs new file mode 100644 index 0000000..bd0622b --- /dev/null +++ b/crates/geoip/src/block.rs @@ -0,0 +1,90 @@ +use {crate::Resolver, bitflags::bitflags, std::net::IpAddr}; + +#[cfg(feature = "middleware")] +pub mod middleware; + +bitflags! { + /// Values used to configure the response behavior when geo data could not be retrieved. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct BlockingPolicy: u8 { + const Block = 0b00; + const AllowMissingCountry = 0b01; + const AllowExtractFailure = 0b10; + const AllowAll = 0b11; + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Country is blocked")] + Blocked, + + #[error("Unable to extract IP address")] + UnableToExtractIPAddress, + + #[error("Unable to extract geo data from IP address")] + UnableToExtractGeoData, + + #[error("Country could not be found in database")] + CountryNotFound, +} + +#[derive(Debug, Clone)] +pub struct CountryFilter { + blocked_countries: Vec, + blocking_policy: BlockingPolicy, +} + +impl CountryFilter { + pub fn new(blocked_countries: Vec, blocking_policy: BlockingPolicy) -> Self { + Self { + blocked_countries, + blocking_policy, + } + } + + /// Checks whether the IP address is blocked. Returns an error if it's + /// blocked or if the lookup has failed for any reason. + pub fn check(&self, addr: IpAddr, resolver: &R) -> Result<(), Error> + where + R: Resolver, + { + let country = resolver + .lookup_geo_data_raw(addr) + .map_err(|_| Error::UnableToExtractGeoData)? + .country + .and_then(|country| country.iso_code) + .ok_or(Error::CountryNotFound)?; + + let blocked = self + .blocked_countries + .iter() + .any(|blocked_country| blocked_country == country); + + if blocked { + Err(Error::Blocked) + } else { + Ok(()) + } + } + + /// Applies selected blocking policy to the [`Blacklist::check()`] result, + /// which may ignore some of the errors. + pub fn apply_policy(&self, check_result: Result<(), Error>) -> Result<(), Error> { + if let Err(err) = check_result { + let policy = self.blocking_policy; + + let is_blocked = matches!(err, Error::UnableToExtractIPAddress | Error::UnableToExtractGeoData if !policy.contains(BlockingPolicy::AllowExtractFailure)) + || matches!(err, Error::CountryNotFound if !policy.contains(BlockingPolicy::AllowMissingCountry)) + || matches!(err, Error::Blocked); + + if is_blocked { + Err(err) + } else { + Ok(()) + } + } else { + Ok(()) + } + } +} diff --git a/crates/geoip/src/block/middleware.rs b/crates/geoip/src/block/middleware.rs new file mode 100644 index 0000000..0a82e95 --- /dev/null +++ b/crates/geoip/src/block/middleware.rs @@ -0,0 +1,151 @@ +//! Middleware which adds geo-location IP blocking. +//! +//! Note: this middleware requires you to use +//! [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) +//! to run your app otherwise it will fail at runtime. +//! +//! See [Router::into_make_service_with_connect_info](https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info) for more details. + +use { + super::{BlockingPolicy, CountryFilter, Error}, + crate::Resolver, + axum_client_ip::InsecureClientIp, + futures::future::{self, Either, Ready}, + http_body::Body, + hyper::{Request, Response, StatusCode}, + std::{ + sync::Arc, + task::{Context, Poll}, + }, + tower::Service, + tower_layer::Layer, +}; + +#[cfg(test)] +mod tests; + +#[derive(Debug)] +struct Inner { + filter: CountryFilter, + ip_resolver: R, +} + +/// Layer that applies the GeoBlock middleware which blocks requests base on IP +/// geo-location. +#[derive(Debug, Clone)] +#[must_use] +pub struct GeoBlockLayer +where + R: Resolver, +{ + inner: Arc>, +} + +impl GeoBlockLayer +where + R: Resolver, +{ + pub fn new( + ip_resolver: R, + blocked_countries: Vec, + blocking_policy: BlockingPolicy, + ) -> Self { + Self { + inner: Arc::new(Inner { + filter: CountryFilter::new(blocked_countries, blocking_policy), + ip_resolver, + }), + } + } +} + +impl Layer for GeoBlockLayer +where + R: Resolver, +{ + type Service = GeoBlockService; + + fn layer(&self, service: S) -> Self::Service { + GeoBlockService { + service, + inner: self.inner.clone(), + } + } +} + +/// Layer that applies the GeoBlock middleware which blocks requests base on IP +/// geo-location. +#[derive(Debug, Clone)] +#[must_use] +pub struct GeoBlockService +where + R: Resolver, +{ + service: S, + inner: Arc>, +} + +impl GeoBlockService +where + R: Resolver, +{ + pub fn new( + service: S, + ip_resolver: R, + blocked_countries: Vec, + blocking_policy: BlockingPolicy, + ) -> Self { + Self { + service, + inner: Arc::new(Inner { + filter: CountryFilter::new(blocked_countries, blocking_policy), + ip_resolver, + }), + } + } +} + +impl Service> for GeoBlockService +where + S: Service, Response = Response>, + R: Resolver, + ResBody: Body + Default, +{ + type Error = S::Error; + type Future = Either, S::Error>>>; + type Response = S::Response; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let inner = self.inner.as_ref(); + + let result = InsecureClientIp::from(request.headers(), request.extensions()) + .map_err(|_| Error::UnableToExtractIPAddress) + .and_then(|client_ip| inner.filter.check(client_ip.0, &inner.ip_resolver)); + + match inner.filter.apply_policy(result) { + Ok(_) => Either::Left(self.service.call(request)), + + Err(err) => { + let code = match err { + Error::Blocked => StatusCode::UNAUTHORIZED, + Error::UnableToExtractIPAddress + | Error::UnableToExtractGeoData + | Error::CountryNotFound => { + tracing::warn!(?err, "failed to check geoblocking"); + + StatusCode::INTERNAL_SERVER_ERROR + } + }; + + let mut response = Response::new(ResBody::default()); + *response.status_mut() = code; + + Either::Right(future::ok(response)) + } + } + } +} diff --git a/crates/geoblock/src/tests.rs b/crates/geoip/src/block/middleware/tests.rs similarity index 62% rename from crates/geoblock/src/tests.rs rename to crates/geoip/src/block/middleware/tests.rs index 51d9a78..12ae1d0 100644 --- a/crates/geoblock/src/tests.rs +++ b/crates/geoip/src/block/middleware/tests.rs @@ -1,10 +1,11 @@ use { - crate::{geoip, BlockingPolicy, GeoBlockLayer}, - hyper::{Body, Request, Response, StatusCode}, - std::{ - convert::Infallible, - net::{IpAddr, Ipv4Addr}, + crate::{ + block::{middleware::GeoBlockLayer, BlockingPolicy}, + LocalResolver, }, + hyper::{Body, Request, Response, StatusCode}, + maxminddb::{geoip2, geoip2::City}, + std::{convert::Infallible, net::IpAddr}, tower::{Service, ServiceBuilder, ServiceExt}, }; @@ -12,27 +13,28 @@ async fn handle(_request: Request) -> Result, Infallible> { Ok(Response::new(Body::empty())) } -fn resolve_ip(caller: IpAddr) -> geoip::GeoData { - if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { - geoip::GeoData { - continent: Some("NA".to_string().into()), - country: Some("CU".to_string().into()), - region: None, - city: None, - } - } else { - geoip::GeoData { - continent: Some("NA".to_string().into()), - country: Some("US".to_string().into()), - region: None, - city: None, - } +fn resolve_ip(_addr: IpAddr) -> City<'static> { + City { + city: None, + continent: None, + country: Some(geoip2::city::Country { + geoname_id: None, + is_in_european_union: None, + iso_code: Some("CU"), + names: None, + }), + location: None, + postal: None, + registered_country: None, + represented_country: None, + subdivisions: None, + traits: None, } } #[tokio::test] async fn test_blocked_country() { - let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); + let resolver = LocalResolver::new(Some(resolve_ip), None); let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); @@ -51,7 +53,7 @@ async fn test_blocked_country() { #[tokio::test] async fn test_non_blocked_country() { - let resolver: geoip::local::LocalResolver = geoip::local::LocalResolver::new(resolve_ip); + let resolver = LocalResolver::new(Some(resolve_ip), None); let blocked_countries = vec!["IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); diff --git a/crates/geoip/src/lib.rs b/crates/geoip/src/lib.rs index 3f65033..14307d9 100644 --- a/crates/geoip/src/lib.rs +++ b/crates/geoip/src/lib.rs @@ -1,20 +1,147 @@ -use std::{net::IpAddr, sync::Arc}; +pub use maxminddb; +use { + aws_sdk_s3::{ + error::SdkError, + operation::get_object::GetObjectError, + primitives::ByteStreamError, + Client as S3Client, + }, + bytes::Bytes, + maxminddb::geoip2::City, + std::{net::IpAddr, sync::Arc}, +}; -pub mod local; -pub mod maxmind; +pub mod block; #[derive(Debug, Clone)] -pub struct GeoData { +pub struct Data { pub continent: Option>, pub country: Option>, pub region: Option>, pub city: Option>, } -pub trait GeoIpResolver: Clone { +pub trait Resolver: Clone { /// The error type produced by the resolver. type Error; + /// Lookup the raw geo data for the given IP address. + fn lookup_geo_data_raw(&self, addr: IpAddr) -> Result, Self::Error>; + /// Lookup the geo data for the given IP address. - fn lookup_geo_data(&self, addr: IpAddr) -> Result; + fn lookup_geo_data(&self, addr: IpAddr) -> Result; +} + +#[derive(Debug, thiserror::Error)] +pub enum LocalResolverError { + #[error("Geoip data lookup is not supported")] + NotSupported, +} + +/// Local resolver that does not need DB files. +#[derive(Debug, Clone)] +pub struct LocalResolver { + resolver_raw: Option City<'static>>, + resolver: Option Data>, +} + +impl LocalResolver { + pub fn new( + resolver_raw: Option City<'static>>, + resolver: Option Data>, + ) -> Self { + Self { + resolver_raw, + resolver, + } + } +} + +impl Resolver for LocalResolver { + type Error = LocalResolverError; + + fn lookup_geo_data_raw(&self, addr: IpAddr) -> Result, Self::Error> { + self.resolver_raw + .ok_or(LocalResolverError::NotSupported) + .map(|resolver| resolver(addr)) + } + + fn lookup_geo_data(&self, addr: IpAddr) -> Result { + self.resolver + .ok_or(LocalResolverError::NotSupported) + .map(|resolver| resolver(addr)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum MaxMindResolverError { + #[error("S3 get object failed: {0}")] + GetObject(#[from] SdkError), + + #[error("Byte stream error: {0}")] + ByteStream(#[from] ByteStreamError), + + #[error("MaxMind DB lookup error: {0}")] + MaxMindDB(#[from] maxminddb::MaxMindDBError), +} + +#[derive(Debug, Clone)] +pub struct MaxMindResolver { + reader: Arc>, +} + +impl MaxMindResolver { + pub async fn from_aws_s3( + s3_client: &S3Client, + bucket: impl Into, + key: impl Into, + ) -> Result { + let s3_object = s3_client + .get_object() + .bucket(bucket) + .key(key) + .send() + .await?; + let geo_data = s3_object.body.collect().await?.into_bytes(); + + Self::from_buffer(geo_data) + } + + pub fn from_buffer(buffer: Bytes) -> Result { + let reader = maxminddb::Reader::from_source(buffer)?; + Ok(Self { + reader: Arc::new(reader), + }) + } +} + +impl Resolver for MaxMindResolver { + type Error = MaxMindResolverError; + + fn lookup_geo_data_raw(&self, addr: IpAddr) -> Result, Self::Error> { + self.reader.lookup::(addr).map_err(Into::into) + } + + fn lookup_geo_data(&self, addr: IpAddr) -> Result { + let lookup_data = self.lookup_geo_data_raw(addr)?; + + Ok(Data { + continent: lookup_data + .continent + .and_then(|continent| continent.code.map(Into::into)), + country: lookup_data + .country + .and_then(|country| country.iso_code.map(Into::into)), + region: lookup_data.subdivisions.map(|divs| { + divs.into_iter() + .filter_map(|div| div.iso_code) + .map(Into::into) + .collect() + }), + city: lookup_data + .city + .and_then(|city| city.names) + .and_then(|city_names| city_names.get("en").copied().map(Into::into)), + }) + } } diff --git a/crates/geoip/src/local.rs b/crates/geoip/src/local.rs deleted file mode 100644 index 23a2f40..0000000 --- a/crates/geoip/src/local.rs +++ /dev/null @@ -1,24 +0,0 @@ -use { - crate::{GeoData, GeoIpResolver}, - std::{convert::Infallible, net::IpAddr}, -}; - -/// Local resolver that does not need DB files. -#[derive(Debug, Clone)] -pub struct LocalResolver { - resolver: fn(IpAddr) -> GeoData, -} - -impl LocalResolver { - pub fn new(resolver: fn(IpAddr) -> GeoData) -> Self { - Self { resolver } - } -} - -impl GeoIpResolver for LocalResolver { - type Error = Infallible; - - fn lookup_geo_data(&self, addr: IpAddr) -> Result { - Ok((self.resolver)(addr)) - } -} diff --git a/crates/geoip/src/maxmind.rs b/crates/geoip/src/maxmind.rs deleted file mode 100644 index 3361b36..0000000 --- a/crates/geoip/src/maxmind.rs +++ /dev/null @@ -1,82 +0,0 @@ -use { - crate::{GeoData, GeoIpResolver}, - aws_sdk_s3::{ - error::SdkError, - operation::get_object::GetObjectError, - primitives::ByteStreamError, - Client as S3Client, - }, - bytes::Bytes, - maxminddb::geoip2::City, - std::{net::IpAddr, sync::Arc}, - thiserror::Error, -}; - -#[derive(Debug, Error)] -pub enum MaxMindResolverError { - #[error("S3 get object failed: {0}")] - GetObject(#[from] SdkError), - - #[error("Byte stream error: {0}")] - ByteStream(#[from] ByteStreamError), - - #[error("MaxMind DB lookup error: {0}")] - MaxMindDB(#[from] maxminddb::MaxMindDBError), -} - -#[derive(Debug, Clone)] -pub struct MaxMindResolver { - reader: Arc>, -} - -impl MaxMindResolver { - pub async fn from_aws_s3( - s3_client: &S3Client, - bucket: impl Into, - key: impl Into, - ) -> Result { - let s3_object = s3_client - .get_object() - .bucket(bucket) - .key(key) - .send() - .await?; - let geo_data = s3_object.body.collect().await?.into_bytes(); - - Self::from_buffer(geo_data) - } - - pub fn from_buffer(buffer: Bytes) -> Result { - let reader = maxminddb::Reader::from_source(buffer)?; - Ok(Self { - reader: Arc::new(reader), - }) - } -} - -impl GeoIpResolver for MaxMindResolver { - type Error = MaxMindResolverError; - - fn lookup_geo_data(&self, addr: IpAddr) -> Result { - let lookup_data = self.reader.lookup::(addr)?; - - Ok(GeoData { - continent: lookup_data - .continent - .and_then(|continent| continent.code.map(Into::into)), - country: lookup_data - .country - .and_then(|country| country.iso_code.map(Into::into)), - region: lookup_data.subdivisions.map(|divs| { - divs.into_iter() - .filter_map(|div| div.iso_code) - .map(Into::into) - .collect() - }), - city: lookup_data - .city - .and_then(|city| city.names) - .and_then(|city_names| city_names.get("en").copied().map(Into::into)), - }) - } -} diff --git a/examples/geoblock.rs b/examples/geoblock.rs index 1b833b1..02ae781 100644 --- a/examples/geoblock.rs +++ b/examples/geoblock.rs @@ -1,39 +1,40 @@ use { - geoblock::{geoip::GeoData, BlockingPolicy, GeoBlockLayer}, hyper::{Body, Request, Response, StatusCode}, - std::{ - convert::Infallible, - net::{IpAddr, Ipv4Addr}, - }, + std::{convert::Infallible, net::IpAddr}, tower::{Service, ServiceBuilder, ServiceExt}, - wc::geoblock::geoip::local::LocalResolver, + wc::geoip::{ + block::{middleware::GeoBlockLayer, BlockingPolicy}, + maxminddb::geoip2, + LocalResolver, + }, }; async fn handle(_request: Request) -> Result, Infallible> { Ok(Response::new(Body::empty())) } -fn resolve_ip(caller: IpAddr) -> GeoData { - if IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) == caller { - GeoData { - continent: Some("NA".to_string().into()), - country: Some("CU".to_string().into()), - region: None, - city: None, - } - } else { - GeoData { - continent: Some("NA".to_string().into()), - country: Some("US".to_string().into()), - region: None, - city: None, - } +fn resolve_ip(_addr: IpAddr) -> geoip2::City<'static> { + geoip2::City { + city: None, + continent: None, + country: Some(geoip2::city::Country { + geoname_id: None, + is_in_european_union: None, + iso_code: Some("CU"), + names: None, + }), + location: None, + postal: None, + registered_country: None, + represented_country: None, + subdivisions: None, + traits: None, } } #[tokio::main] async fn main() -> Result<(), Box> { - let resolver: LocalResolver = LocalResolver::new(|caller| resolve_ip(caller)); + let resolver: LocalResolver = LocalResolver::new(Some(|caller| resolve_ip(caller)), None); let blocked_countries = vec!["CU".into(), "IR".into(), "KP".into()]; let geoblock = GeoBlockLayer::new(resolver, blocked_countries, BlockingPolicy::Block); diff --git a/src/lib.rs b/src/lib.rs index 48225a7..5c99951 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,8 +7,6 @@ pub use analytics; pub use collections; #[cfg(feature = "future")] pub use future; -#[cfg(feature = "geoblock")] -pub use geoblock; #[cfg(feature = "geoip")] pub use geoip; #[cfg(feature = "http")]