From fead698fb45c52aa981299ef90a47e934f4c43e5 Mon Sep 17 00:00:00 2001 From: Aumetra Weisman Date: Thu, 19 Dec 2024 11:01:38 +0100 Subject: [PATCH] use rfc-compliant granular errors --- Cargo.lock | 1 + lib/komainu/Cargo.toml | 1 + lib/komainu/benches/pkce.rs | 14 ++++----- lib/komainu/src/code_grant.rs | 2 +- lib/komainu/src/error.rs | 11 ------- lib/komainu/src/extract.rs | 7 ++--- lib/komainu/src/flow/authorization.rs | 31 +++++++++---------- lib/komainu/src/flow/mod.rs | 43 +++++++++++++++++++++++---- lib/komainu/src/flow/pkce.rs | 12 ++++---- lib/komainu/src/flow/refresh.rs | 16 +++++----- lib/komainu/src/lib.rs | 4 +-- lib/komainu/tests/pkce.rs | 6 ++-- 12 files changed, 85 insertions(+), 63 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1d22735d8..4e95f9453 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3984,6 +3984,7 @@ dependencies = [ "divan", "headers", "http", + "insta", "memchr", "rstest", "serde", diff --git a/lib/komainu/Cargo.toml b/lib/komainu/Cargo.toml index 8f856749c..394f130c8 100644 --- a/lib/komainu/Cargo.toml +++ b/lib/komainu/Cargo.toml @@ -32,6 +32,7 @@ url.workspace = true [dev-dependencies] divan.workspace = true headers.workspace = true +insta.workspace = true rstest.workspace = true serde_test.workspace = true diff --git a/lib/komainu/benches/pkce.rs b/lib/komainu/benches/pkce.rs index 06bba6a51..1fac7c6b8 100644 --- a/lib/komainu/benches/pkce.rs +++ b/lib/komainu/benches/pkce.rs @@ -1,17 +1,17 @@ use divan::black_box; -use komainu::flow::{PkceMethod, PkcePayload}; +use komainu::flow::pkce; use std::borrow::Cow; #[global_allocator] static GLOBAL: divan::AllocProfiler = divan::AllocProfiler::system(); #[divan::bench] -fn s256() -> komainu::Result<()> { +fn s256() -> Result<(), komainu::flow::FlowError> { let verifier_base64 = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge_base64 = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; - let payload = PkcePayload { - method: black_box(PkceMethod::S256), + let payload = pkce::Payload { + method: black_box(pkce::Method::S256), challenge: black_box(Cow::Borrowed(challenge_base64)), }; @@ -19,11 +19,11 @@ fn s256() -> komainu::Result<()> { } #[divan::bench] -fn none() -> komainu::Result<()> { +fn none() -> Result<(), komainu::flow::FlowError> { let value = "arbitrary value"; - let payload = PkcePayload { - method: black_box(PkceMethod::None), + let payload = pkce::Payload { + method: black_box(pkce::Method::None), challenge: black_box(Cow::Borrowed(value)), }; diff --git a/lib/komainu/src/code_grant.rs b/lib/komainu/src/code_grant.rs index 902c7a21a..5347d41aa 100644 --- a/lib/komainu/src/code_grant.rs +++ b/lib/komainu/src/code_grant.rs @@ -72,7 +72,7 @@ where let client_id = query.get("client_id").or_invalid_request()?; let response_type = query.get("response_type").or_invalid_request()?; - let scope = query.get("scope").map(Deref::deref).unwrap_or(""); + let scope = query.get("scope").map_or("", Deref::deref); let state = query.get("state").map(|state| &**state); let client = self.client_extractor.extract(client_id, None).await?; diff --git a/lib/komainu/src/error.rs b/lib/komainu/src/error.rs index 18530524c..92b64d3e7 100644 --- a/lib/komainu/src/error.rs +++ b/lib/komainu/src/error.rs @@ -2,8 +2,6 @@ use thiserror::Error; type BoxError = Box; -pub type Result = std::result::Result; - #[derive(Debug, Error)] pub enum Error { #[error("Malformed body")] @@ -24,12 +22,3 @@ impl Error { Self::Query(err.into()) } } - -macro_rules! ensure { - ($cond:expr, $err:expr) => {{ - if !{ $cond } { - return Err($err); - } - }}; -} -pub(crate) use ensure; diff --git a/lib/komainu/src/extract.rs b/lib/komainu/src/extract.rs index b6da1d188..9350887c9 100644 --- a/lib/komainu/src/extract.rs +++ b/lib/komainu/src/extract.rs @@ -1,7 +1,4 @@ -use crate::{ - error::{Error, Result}, - params::ParamStorage, -}; +use crate::{error::Error, params::ParamStorage}; use bytes::Bytes; use memchr::memchr; @@ -9,7 +6,7 @@ static URL_ENCODED_CONTENT_TYPE: http::HeaderValue = http::HeaderValue::from_static("application/x-www-form-urlencoded"); #[inline] -pub fn body<'a, T>(req: &'a http::Request) -> Result +pub fn body<'a, T>(req: &'a http::Request) -> Result where T: serde::Deserialize<'a>, { diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index 1bd072c30..4e60f73dc 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -1,7 +1,6 @@ use crate::{ - error::{Error, Result}, extract::ClientCredentials, - flow::TokenResponse, + flow::{FlowError, OptionExt, TokenResponse}, params::ParamStorage, Authorization, ClientExtractor, }; @@ -12,12 +11,12 @@ pub trait Issuer { fn load_authorization( &self, auth_code: &str, - ) -> impl Future>>> + Send; + ) -> impl Future>, FlowError>> + Send; fn issue_token( &self, authorization: &Authorization<'_>, - ) -> impl Future>> + Send; + ) -> impl Future, FlowError>> + Send; } #[instrument(skip_all)] @@ -25,26 +24,27 @@ pub async fn perform( req: http::Request, client_extractor: CE, token_issuer: I, -) -> Result> +) -> Result, FlowError> where CE: ClientExtractor, I: Issuer, { let body: ParamStorage<&str, &str> = crate::extract::body(&req)?; - let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?; + let client_credentials = + ClientCredentials::extract(req.headers(), &body).or_invalid_request()?; let (client_id, client_secret) = ( client_credentials.client_id(), client_credentials.client_secret(), ); - let grant_type = body.get("grant_type").or_missing_param()?; - let code = body.get("code").or_missing_param()?; - let redirect_uri = body.get("redirect_uri").or_missing_param()?; + let grant_type = body.get("grant_type").or_invalid_request()?; + let code = body.get("code").or_invalid_request()?; + let redirect_uri = body.get("redirect_uri").or_invalid_request()?; if *grant_type != "authorization_code" { error!(?client_id, "grant_type is not authorization_code"); - return Err(Error::Unauthorized); + return Err(FlowError::UnsupportedGrantType); } let client = client_extractor @@ -53,19 +53,20 @@ where if client.redirect_uri != *redirect_uri { error!(?client_id, "redirect uri doesn't match"); - return Err(Error::Unauthorized); + return Err(FlowError::InvalidClient); } - let maybe_authorization = token_issuer.load_authorization(code).await?; - let authorization = maybe_authorization.or_unauthorized()?; + let Some(authorization) = token_issuer.load_authorization(code).await? else { + return Err(FlowError::InvalidGrant); + }; // This check is constant time :3 if client != authorization.client { - return Err(Error::Unauthorized); + return Err(FlowError::UnauthorizedClient); } if let Some(ref pkce) = authorization.pkce_payload { - let code_verifier = body.get("code_verifier").or_unauthorized()?; + let code_verifier = body.get("code_verifier").or_invalid_request()?; pkce.verify(code_verifier)?; } diff --git a/lib/komainu/src/flow/mod.rs b/lib/komainu/src/flow/mod.rs index 8b5e7fa48..a5f10721b 100644 --- a/lib/komainu/src/flow/mod.rs +++ b/lib/komainu/src/flow/mod.rs @@ -1,10 +1,37 @@ use serde::Serialize; use std::borrow::Cow; +use strum::Display; +use thiserror::Error; pub mod authorization; pub mod pkce; pub mod refresh; +trait OptionExt { + fn or_invalid_request(self) -> Result; +} + +impl OptionExt for Option { + #[inline] + fn or_invalid_request(self) -> Result { + self.ok_or(FlowError::InvalidRequest) + } +} + +#[derive(Debug, Display, Error, Serialize)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum FlowError { + InvalidRequest, + InvalidClient, + InvalidGrant, + UnauthorizedClient, + UnsupportedGrantType, + InvalidScope, + #[serde(skip)] + Other(#[from] crate::error::Error), +} + #[derive(Serialize)] #[serde(rename_all = "snake_case")] #[non_exhaustive] @@ -13,9 +40,15 @@ pub enum TokenType { } #[derive(Serialize)] -pub struct TokenResponse<'a> { - pub access_token: Cow<'a, str>, - pub token_type: TokenType, - pub refresh_token: Cow<'a, str>, - pub expires_in: u64, +#[serde(untagged)] +pub enum TokenResponse<'a> { + Success { + access_token: Cow<'a, str>, + token_type: TokenType, + refresh_token: Cow<'a, str>, + expires_in: u64, + }, + Error { + errorr: FlowError, + }, } diff --git a/lib/komainu/src/flow/pkce.rs b/lib/komainu/src/flow/pkce.rs index bb9af7300..ab18441cc 100644 --- a/lib/komainu/src/flow/pkce.rs +++ b/lib/komainu/src/flow/pkce.rs @@ -1,4 +1,4 @@ -use crate::error::Error; +use crate::{error::Error, flow::FlowError}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::borrow::Cow; @@ -22,7 +22,7 @@ pub struct Payload<'a> { impl Payload<'_> { #[inline] - fn verify_s256(&self, code_verifier: &str) -> Result<()> { + fn verify_s256(&self, code_verifier: &str) -> Result<(), FlowError> { let decoded = base64_simd::URL_SAFE_NO_PAD .decode_to_vec(self.challenge.as_bytes()) .inspect_err(|error| debug!(?error, "failed to decode pkce payload")) @@ -32,22 +32,22 @@ impl Payload<'_> { if decoded.ct_eq(hash.as_slice()).into() { Ok(()) } else { - Err(Error::Unauthorized) + Err(FlowError::InvalidGrant) } } #[inline] - fn verify_none(&self, code_verifier: &str) -> Result<()> { + fn verify_none(&self, code_verifier: &str) -> Result<(), FlowError> { let challenge_bytes = self.challenge.as_bytes(); if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() { Ok(()) } else { - Err(Error::Unauthorized) + Err(FlowError::InvalidGrant) } } #[inline] - pub fn verify(&self, code_verifier: &str) -> Result<()> { + pub fn verify(&self, code_verifier: &str) -> Result<(), FlowError> { match self.method { Method::None => self.verify_none(code_verifier), Method::S256 => self.verify_s256(code_verifier), diff --git a/lib/komainu/src/flow/refresh.rs b/lib/komainu/src/flow/refresh.rs index 1e6a744b5..74c580bb8 100644 --- a/lib/komainu/src/flow/refresh.rs +++ b/lib/komainu/src/flow/refresh.rs @@ -1,7 +1,6 @@ -use super::TokenResponse; use crate::{ - error::{Error, Result}, extract::ClientCredentials, + flow::{FlowError, OptionExt, TokenResponse}, params::ParamStorage, Client, ClientExtractor, }; @@ -13,7 +12,7 @@ pub trait Issuer { &self, client: &Client<'_>, refresh_token: &str, - ) -> impl Future>> + Send; + ) -> impl Future, FlowError>> + Send; } #[instrument(skip_all)] @@ -21,25 +20,26 @@ pub async fn perform( req: http::Request, client_extractor: CE, token_issuer: I, -) -> Result> +) -> Result, FlowError> where CE: ClientExtractor, I: Issuer, { let body: ParamStorage<&str, &str> = crate::extract::body(&req)?; - let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?; + let client_credentials = + ClientCredentials::extract(req.headers(), &body).or_invalid_request()?; let (client_id, client_secret) = ( client_credentials.client_id(), client_credentials.client_secret(), ); - let grant_type = body.get("grant_type").or_missing_param()?; - let refresh_token = body.get("refresh_token").or_missing_param()?; + let grant_type = body.get("grant_type").or_invalid_request()?; + let refresh_token = body.get("refresh_token").or_invalid_request()?; if *grant_type != "refresh_token" { debug!(?client_id, "grant_type is not refresh_token"); - return Err(Error::Unauthorized); + return Err(FlowError::UnsupportedGrantType); } let client = client_extractor diff --git a/lib/komainu/src/lib.rs b/lib/komainu/src/lib.rs index 0fa0fd129..cca160468 100644 --- a/lib/komainu/src/lib.rs +++ b/lib/komainu/src/lib.rs @@ -5,7 +5,7 @@ use self::flow::pkce; use std::{borrow::Cow, future::Future}; use subtle::ConstantTimeEq; -pub use self::error::{Error, Result}; +pub use self::error::Error; pub use self::params::ParamStorage; mod error; @@ -54,5 +54,5 @@ pub trait ClientExtractor { &self, client_id: &str, client_secret: Option<&str>, - ) -> impl Future>> + Send; + ) -> impl Future, Error>> + Send; } diff --git a/lib/komainu/tests/pkce.rs b/lib/komainu/tests/pkce.rs index 300811b91..b93d4e088 100644 --- a/lib/komainu/tests/pkce.rs +++ b/lib/komainu/tests/pkce.rs @@ -1,4 +1,4 @@ -use komainu::flow::{PkceMethod, PkcePayload}; +use komainu::flow::pkce; use std::borrow::Cow; #[test] @@ -25,8 +25,8 @@ fn verify_rfc_payload_s256() { challenge_base64 ); - let payload = PkcePayload { - method: PkceMethod::S256, + let payload = pkce::Payload { + method: pkce::Method::S256, challenge: Cow::Borrowed(challenge_base64), }; payload.verify(verifier_base64).unwrap();