diff --git a/lib/komainu/benches/basic_auth.rs b/lib/komainu/benches/basic_auth.rs index 0d622566..cd9c7a6f 100644 --- a/lib/komainu/benches/basic_auth.rs +++ b/lib/komainu/benches/basic_auth.rs @@ -25,7 +25,7 @@ mod headers { #[divan::bench_group] mod ours { use divan::{black_box, black_box_drop}; - use komainu::extractor::BasicAuth; + use komainu::extract::BasicAuth; #[divan::bench] fn rfc_value(b: divan::Bencher<'_, '_>) { diff --git a/lib/komainu/src/authorize.rs b/lib/komainu/src/authorize.rs index 08df0633..8070d9cc 100644 --- a/lib/komainu/src/authorize.rs +++ b/lib/komainu/src/authorize.rs @@ -97,6 +97,15 @@ where } } +macro_rules! return_err { + ($result:expr) => {{ + match { $result } { + Ok(val) => val, + Err(err) => return err, + } + }}; +} + pub struct Authorizer<'a, I> { issuer: &'a I, client: Client<'a>, @@ -131,6 +140,18 @@ where .unwrap() } + #[inline] + fn redirect_uri(&self) -> Result> { + url::Url::parse(&self.client.redirect_uri).map_err(|error| { + error!(?error, redirect_uri = ?self.client.redirect_uri, "invalid redirect uri"); + + http::Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body(()) + .unwrap() + }) + } + #[inline] #[instrument(skip_all)] pub async fn accept(self, user_id: I::UserId, scopes: &[&str]) -> http::Response<()> { @@ -140,13 +161,19 @@ where pkce_payload: self.pkce_payload.as_ref(), }; - let code = self - .issuer - .issue_code(user_id, pre_authorization) - .await - .unwrap(); + let mut url = return_err!(self.redirect_uri()); + + let code = match self.issuer.issue_code(user_id, pre_authorization).await { + Ok(code) => code, + Err(error) => { + debug!(?error, "failed to issue code"); + url.query_pairs_mut() + .append_pair("error", OAuthError::TemporarilyUnavailable.as_ref()); + + return Self::build_response(url); + } + }; - let mut url = url::Url::parse(&self.client.redirect_uri).unwrap(); url.query_pairs_mut().append_pair("code", &code); if let Some(state) = self.state { @@ -160,7 +187,7 @@ where #[must_use] #[instrument(skip_all)] pub fn deny(self) -> http::Response<()> { - let mut url = url::Url::parse(&self.client.redirect_uri).unwrap(); + let mut url = return_err!(self.redirect_uri()); url.query_pairs_mut() .append_pair("error", OAuthError::AccessDenied.as_ref()); diff --git a/lib/komainu/src/error.rs b/lib/komainu/src/error.rs index 4a754624..cfda21f9 100644 --- a/lib/komainu/src/error.rs +++ b/lib/komainu/src/error.rs @@ -33,18 +33,6 @@ impl Error { } } -impl From for OAuthError { - #[track_caller] - fn from(value: Error) -> Self { - debug!(error = ?value); - - match value { - Error::Body(..) | Error::MissingParam | Error::Query(..) => Self::InvalidRequest, - Error::Unauthorized => Self::AccessDenied, - } - } -} - #[derive(AsRefStr, Serialize)] #[serde(rename_all = "snake_case")] #[strum(serialize_all = "snake_case")] @@ -62,31 +50,3 @@ pub enum OAuthError { pub struct OAuthErrorResponse { pub error: OAuthError, } - -macro_rules! fallible { - ($op:expr) => {{ - match { $op } { - Ok(val) => val, - Err(error) => { - debug!(?error); - $crate::error::yield_error!(error); - } - } - }}; -} - -macro_rules! yield_error { - (@ser $error:expr) => {{ - return ::http::Response::builder() - .status(::http::StatusCode::BAD_REQUEST) - .body(sonic_rs::to_vec(&$error).unwrap().into()) - .unwrap(); - }}; - ($error:expr) => {{ - $crate::error::yield_error!(@ser $crate::error::OAuthErrorResponse { - error: $error.into(), - }); - }}; -} - -pub(crate) use {fallible, yield_error}; diff --git a/lib/komainu/src/extractor.rs b/lib/komainu/src/extract.rs similarity index 100% rename from lib/komainu/src/extractor.rs rename to lib/komainu/src/extract.rs diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index 303dd48c..cd8949ff 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -1,9 +1,9 @@ use super::TokenResponse; use crate::{ - error::{fallible, yield_error, Result}, - extractor::ClientCredentials, + error::{Error, Result}, + extract::ClientCredentials, params::ParamStorage, - Authorization, ClientExtractor, Error, OptionExt, + Authorization, ClientExtractor, OptionExt, }; use bytes::Bytes; use std::future::Future; @@ -25,61 +25,59 @@ pub async fn perform( req: http::Request, client_extractor: CE, token_issuer: I, -) -> http::Response +) -> Result> where CE: ClientExtractor, I: Issuer, { - let body: ParamStorage<&str, &str> = fallible!(crate::extractor::body(&req)); - - let client_credentials = - fallible!(ClientCredentials::extract(req.headers(), &body).or_unauthorized()); + let body: ParamStorage<&str, &str> = crate::extract::body(&req)?; + let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?; let (client_id, client_secret) = ( client_credentials.client_id(), client_credentials.client_secret(), ); - let grant_type = fallible!(body.get("grant_type").or_missing_param()); - let code = fallible!(body.get("code").or_missing_param()); - let redirect_uri = fallible!(body.get("redirect_uri").or_missing_param()); + let grant_type = body.get("grant_type").or_missing_param()?; + let code = body.get("code").or_missing_param()?; + let redirect_uri = body.get("redirect_uri").or_missing_param()?; if *grant_type != "authorization_code" { error!(?client_id, "grant_type is not authorization_code"); - yield_error!(Error::Unauthorized); + return Err(Error::Unauthorized); } - let client = fallible!( - client_extractor - .extract(client_id, Some(client_secret)) - .await - ); + let client = client_extractor + .extract(client_id, Some(client_secret)) + .await?; if client.redirect_uri != *redirect_uri { error!(?client_id, "redirect uri doesn't match"); - yield_error!(Error::Unauthorized); + return Err(Error::Unauthorized); } - let maybe_authorization = fallible!(token_issuer.load_authorization(code).await); - let authorization = fallible!(maybe_authorization.or_unauthorized()); + let maybe_authorization = token_issuer.load_authorization(code).await?; + let authorization = maybe_authorization.or_unauthorized()?; // This check is constant time :3 if client != authorization.client { - yield_error!(Error::Unauthorized); + return Err(Error::Unauthorized); } if let Some(ref pkce) = authorization.pkce_payload { - let code_verifier = fallible!(body.get("code_verifier").or_unauthorized()); - fallible!(pkce.verify(code_verifier)); + let code_verifier = body.get("code_verifier").or_unauthorized()?; + pkce.verify(code_verifier)?; } - let token = fallible!(token_issuer.issue_token(&authorization).await); + let token = token_issuer.issue_token(&authorization).await?; let body = sonic_rs::to_vec(&token).unwrap(); debug!("token successfully issued. building response"); - http::Response::builder() + let response = http::Response::builder() .status(http::StatusCode::OK) .body(body.into()) - .unwrap() + .unwrap(); + + Ok(response) } diff --git a/lib/komainu/src/flow/refresh.rs b/lib/komainu/src/flow/refresh.rs index 1c22b0ff..1cce4ffc 100644 --- a/lib/komainu/src/flow/refresh.rs +++ b/lib/komainu/src/flow/refresh.rs @@ -1,7 +1,7 @@ use super::TokenResponse; use crate::{ - error::{fallible, yield_error, Error, Result}, - extractor::ClientCredentials, + error::{Error, Result}, + extract::ClientCredentials, params::ParamStorage, Client, ClientExtractor, OptionExt, }; @@ -21,42 +21,40 @@ pub async fn perform( req: http::Request, client_extractor: CE, token_issuer: I, -) -> http::Response +) -> Result> where CE: ClientExtractor, I: Issuer, { - let body: ParamStorage<&str, &str> = fallible!(crate::extractor::body(&req)); - - let client_credentials = - fallible!(ClientCredentials::extract(req.headers(), &body).or_unauthorized()); + let body: ParamStorage<&str, &str> = crate::extract::body(&req)?; + let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?; let (client_id, client_secret) = ( client_credentials.client_id(), client_credentials.client_secret(), ); - let grant_type = fallible!(body.get("grant_type").or_missing_param()); - let refresh_token = fallible!(body.get("refresh_token").or_missing_param()); + let grant_type = body.get("grant_type").or_missing_param()?; + let refresh_token = body.get("refresh_token").or_missing_param()?; if *grant_type != "refresh_token" { debug!(?client_id, "grant_type is not refresh_token"); - yield_error!(Error::Unauthorized); + return Err(Error::Unauthorized); } - let client = fallible!( - client_extractor - .extract(client_id, Some(client_secret)) - .await - ); + let client = client_extractor + .extract(client_id, Some(client_secret)) + .await?; - let token = fallible!(token_issuer.issue_token(&client, refresh_token).await); + let token = token_issuer.issue_token(&client, refresh_token).await?; let body = sonic_rs::to_vec(&token).unwrap(); debug!("token successfully issued. building response"); - http::Response::builder() + let response = http::Response::builder() .status(http::StatusCode::OK) .body(body.into()) - .unwrap() + .unwrap(); + + Ok(response) } diff --git a/lib/komainu/src/lib.rs b/lib/komainu/src/lib.rs index d98996d3..b66109a8 100644 --- a/lib/komainu/src/lib.rs +++ b/lib/komainu/src/lib.rs @@ -11,7 +11,7 @@ pub use self::params::ParamStorage; mod error; pub mod authorize; -pub mod extractor; +pub mod extract; pub mod flow; pub mod params; diff --git a/lib/komainu/tests/basic_auth.rs b/lib/komainu/tests/basic_auth.rs index 2ccb06d1..cb4c2f3c 100644 --- a/lib/komainu/tests/basic_auth.rs +++ b/lib/komainu/tests/basic_auth.rs @@ -1,4 +1,4 @@ -use komainu::extractor::BasicAuth; +use komainu::extract::BasicAuth; use rstest::rstest; #[test]