diff --git a/lib/komainu/src/authorize.rs b/lib/komainu/src/code_grant.rs similarity index 65% rename from lib/komainu/src/authorize.rs rename to lib/komainu/src/code_grant.rs index 8070d9cc..9d681bd8 100644 --- a/lib/komainu/src/authorize.rs +++ b/lib/komainu/src/code_grant.rs @@ -1,8 +1,5 @@ use crate::{ - error::{Error, OAuthError, Result}, - flow::{PkceMethod, PkcePayload}, - params::ParamStorage, - Client, ClientExtractor, OptionExt, PreAuthorization, + error::Error, flow::pkce, params::ParamStorage, AuthInstruction, Client, ClientExtractor, }; use std::{ borrow::{Borrow, Cow}, @@ -10,6 +7,32 @@ use std::{ future::Future, str::FromStr, }; +use strum::{AsRefStr, Display}; +use thiserror::Error; + +trait OptionExt { + fn or_invalid_request(self) -> Result; +} + +impl OptionExt for Option { + #[inline] + fn or_invalid_request(self) -> Result { + self.ok_or(GrantError::InvalidRequest) + } +} + +#[derive(AsRefStr, Debug, Display, Error)] +#[strum(serialize_all = "snake_case")] +pub enum GrantError { + InvalidRequest, + UnauthorizedClient, + AccessDenied, + UnsupportedResponseType, + InvalidScope, + ServerError, + TemporarilyUnavailable, + Other(#[from] Error), +} pub trait Issuer { type UserId; @@ -17,8 +40,8 @@ pub trait Issuer { fn issue_code( &self, user_id: Self::UserId, - pre_authorization: PreAuthorization<'_, '_>, - ) -> impl Future> + Send; + pre_authorization: AuthInstruction<'_, '_>, + ) -> impl Future> + Send; } pub struct AuthorizerExtractor { @@ -38,26 +61,29 @@ where } #[instrument(skip_all)] - pub async fn extract<'a>(&'a self, req: &'a http::Request<()>) -> Result> { + pub async fn extract<'a>( + &'a self, + req: &'a http::Request<()>, + ) -> Result, GrantError> { let query: ParamStorage<&str, &str> = - serde_urlencoded::from_str(req.uri().query().or_missing_param()?) + serde_urlencoded::from_str(req.uri().query().or_invalid_request()?) .map_err(Error::query)?; - let client_id = query.get("client_id").or_missing_param()?; - let response_type = query.get("response_type").or_missing_param()?; + let client_id = query.get("client_id").or_invalid_request()?; + let response_type = query.get("response_type").or_invalid_request()?; if *response_type != "code" { debug!(?client_id, "response_type not set to \"code\""); - return Err(Error::Unauthorized); + return Err(GrantError::AccessDenied); } - let scope = query.get("scope").or_missing_param()?; - let redirect_uri = query.get("redirect_uri").or_missing_param()?; + let scope = query.get("scope").or_invalid_request()?; + let redirect_uri = query.get("redirect_uri").or_invalid_request()?; let state = query.get("state").map(|state| &**state); let client = self.client_extractor.extract(client_id, None).await?; if client.redirect_uri != *redirect_uri { debug!(?client_id, "redirect uri doesn't match"); - return Err(Error::Unauthorized); + return Err(GrantError::AccessDenied); } let request_scopes = scope.split_whitespace().collect::>(); @@ -69,17 +95,17 @@ where if !request_scopes.is_subset(&client_scopes) { debug!(?client_id, "scopes aren't a subset"); - return Err(Error::Unauthorized); + return Err(GrantError::AccessDenied); } let pkce_payload = if let Some(challenge) = query.get("code_challenge") { let method = if let Some(method) = query.get("challenge_code_method") { - PkceMethod::from_str(method).map_err(Error::query)? + pkce::Method::from_str(method).map_err(Error::query)? } else { - PkceMethod::default() + pkce::Method::default() }; - Some(PkcePayload { + Some(pkce::Payload { method, challenge: Cow::Borrowed(challenge), }) @@ -109,7 +135,7 @@ macro_rules! return_err { pub struct Authorizer<'a, I> { issuer: &'a I, client: Client<'a>, - pkce_payload: Option>, + pkce_payload: Option>, query: ParamStorage<&'a str, &'a str>, state: Option<&'a str>, } @@ -152,28 +178,31 @@ where }) } + #[inline] + fn build_error_response(&self, error: &GrantError) -> http::Response<()> { + let mut uri = return_err!(self.redirect_uri()); + uri.query_pairs_mut().append_pair("error", error.as_ref()); + Self::build_response(uri) + } + #[inline] #[instrument(skip_all)] pub async fn accept(self, user_id: I::UserId, scopes: &[&str]) -> http::Response<()> { - let pre_authorization = PreAuthorization { + let pre_authorization = AuthInstruction { client: &self.client, scopes, pkce_payload: self.pkce_payload.as_ref(), }; - let mut url = return_err!(self.redirect_uri()); - let code = match self.issuer.issue_code(user_id, pre_authorization).await { Ok(code) => code, Err(error) => { debug!(?error, "failed to issue code"); - url.query_pairs_mut() - .append_pair("error", OAuthError::TemporarilyUnavailable.as_ref()); - - return Self::build_response(url); + return self.build_error_response(&GrantError::TemporarilyUnavailable); } }; + let mut url = return_err!(self.redirect_uri()); url.query_pairs_mut().append_pair("code", &code); if let Some(state) = self.state { @@ -187,10 +216,6 @@ where #[must_use] #[instrument(skip_all)] pub fn deny(self) -> http::Response<()> { - let mut url = return_err!(self.redirect_uri()); - url.query_pairs_mut() - .append_pair("error", OAuthError::AccessDenied.as_ref()); - - Self::build_response(url) + self.build_error_response(&GrantError::AccessDenied) } } diff --git a/lib/komainu/src/error.rs b/lib/komainu/src/error.rs index cfda21f9..789c9c02 100644 --- a/lib/komainu/src/error.rs +++ b/lib/komainu/src/error.rs @@ -1,5 +1,3 @@ -use serde::Serialize; -use strum::AsRefStr; use thiserror::Error; type BoxError = Box; @@ -11,14 +9,8 @@ pub enum Error { #[error("Malformed body")] Body(#[source] BoxError), - #[error("Missing parameter")] - MissingParam, - #[error("Malformed query")] Query(#[source] BoxError), - - #[error("Request is unauthorized")] - Unauthorized, } impl Error { @@ -32,21 +24,3 @@ impl Error { Self::Query(err.into()) } } - -#[derive(AsRefStr, Serialize)] -#[serde(rename_all = "snake_case")] -#[strum(serialize_all = "snake_case")] -pub enum OAuthError { - InvalidRequest, - UnauthorizedClient, - AccessDenied, - UnsupportedResponseType, - InvalidScope, - ServerError, - TemporarilyUnavailable, -} - -#[derive(Serialize)] -pub struct OAuthErrorResponse { - pub error: OAuthError, -} diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index cd8949ff..1bd072c3 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -1,9 +1,9 @@ -use super::TokenResponse; use crate::{ error::{Error, Result}, extract::ClientCredentials, + flow::TokenResponse, params::ParamStorage, - Authorization, ClientExtractor, OptionExt, + Authorization, ClientExtractor, }; use bytes::Bytes; use std::future::Future; diff --git a/lib/komainu/src/flow/mod.rs b/lib/komainu/src/flow/mod.rs index 807c571f..8b5e7fa4 100644 --- a/lib/komainu/src/flow/mod.rs +++ b/lib/komainu/src/flow/mod.rs @@ -1,11 +1,8 @@ -use crate::{Error, Result}; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; +use serde::Serialize; use std::borrow::Cow; -use strum::{AsRefStr, EnumString}; -use subtle::ConstantTimeEq; pub mod authorization; +pub mod pkce; pub mod refresh; #[derive(Serialize)] @@ -22,53 +19,3 @@ pub struct TokenResponse<'a> { pub refresh_token: Cow<'a, str>, pub expires_in: u64, } - -#[derive(AsRefStr, Default, Deserialize, EnumString, Serialize)] -#[strum(serialize_all = "snake_case")] -pub enum PkceMethod { - #[default] - None, - #[strum(serialize = "S256")] - S256, -} - -#[derive(Deserialize, Serialize)] -pub struct PkcePayload<'a> { - pub challenge: Cow<'a, str>, - pub method: PkceMethod, -} - -impl PkcePayload<'_> { - #[inline] - fn verify_s256(&self, code_verifier: &str) -> Result<()> { - let decoded = base64_simd::URL_SAFE_NO_PAD - .decode_to_vec(self.challenge.as_bytes()) - .inspect_err(|error| debug!(?error, "failed to decode pkce payload")) - .map_err(Error::body)?; - - let hash = Sha256::digest(code_verifier); - if decoded.ct_eq(hash.as_slice()).into() { - Ok(()) - } else { - Err(Error::Unauthorized) - } - } - - #[inline] - fn verify_none(&self, code_verifier: &str) -> Result<()> { - let challenge_bytes = self.challenge.as_bytes(); - if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() { - Ok(()) - } else { - Err(Error::Unauthorized) - } - } - - #[inline] - pub fn verify(&self, code_verifier: &str) -> Result<()> { - match self.method { - PkceMethod::None => self.verify_none(code_verifier), - PkceMethod::S256 => self.verify_s256(code_verifier), - } - } -} diff --git a/lib/komainu/src/flow/pkce.rs b/lib/komainu/src/flow/pkce.rs new file mode 100644 index 00000000..bb9af730 --- /dev/null +++ b/lib/komainu/src/flow/pkce.rs @@ -0,0 +1,56 @@ +use crate::error::Error; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::borrow::Cow; +use strum::{AsRefStr, EnumString}; +use subtle::ConstantTimeEq; + +#[derive(AsRefStr, Default, Deserialize, EnumString, Serialize)] +#[strum(serialize_all = "snake_case")] +pub enum Method { + #[default] + None, + #[strum(serialize = "S256")] + S256, +} + +#[derive(Deserialize, Serialize)] +pub struct Payload<'a> { + pub challenge: Cow<'a, str>, + pub method: Method, +} + +impl Payload<'_> { + #[inline] + fn verify_s256(&self, code_verifier: &str) -> Result<()> { + let decoded = base64_simd::URL_SAFE_NO_PAD + .decode_to_vec(self.challenge.as_bytes()) + .inspect_err(|error| debug!(?error, "failed to decode pkce payload")) + .map_err(Error::body)?; + + let hash = Sha256::digest(code_verifier); + if decoded.ct_eq(hash.as_slice()).into() { + Ok(()) + } else { + Err(Error::Unauthorized) + } + } + + #[inline] + fn verify_none(&self, code_verifier: &str) -> Result<()> { + let challenge_bytes = self.challenge.as_bytes(); + if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() { + Ok(()) + } else { + Err(Error::Unauthorized) + } + } + + #[inline] + pub fn verify(&self, code_verifier: &str) -> Result<()> { + match self.method { + Method::None => self.verify_none(code_verifier), + Method::S256 => self.verify_s256(code_verifier), + } + } +} diff --git a/lib/komainu/src/flow/refresh.rs b/lib/komainu/src/flow/refresh.rs index 1cce4ffc..1e6a744b 100644 --- a/lib/komainu/src/flow/refresh.rs +++ b/lib/komainu/src/flow/refresh.rs @@ -3,7 +3,7 @@ use crate::{ error::{Error, Result}, extract::ClientCredentials, params::ParamStorage, - Client, ClientExtractor, OptionExt, + Client, ClientExtractor, }; use bytes::Bytes; use std::future::Future; diff --git a/lib/komainu/src/lib.rs b/lib/komainu/src/lib.rs index b66109a8..0fa0fd12 100644 --- a/lib/komainu/src/lib.rs +++ b/lib/komainu/src/lib.rs @@ -1,7 +1,7 @@ #[macro_use] extern crate tracing; -use self::flow::PkcePayload; +use self::flow::pkce; use std::{borrow::Cow, future::Future}; use subtle::ConstantTimeEq; @@ -10,38 +10,21 @@ pub use self::params::ParamStorage; mod error; -pub mod authorize; +pub mod code_grant; pub mod extract; pub mod flow; pub mod params; -trait OptionExt { - fn or_missing_param(self) -> Result; - fn or_unauthorized(self) -> Result; -} - -impl OptionExt for Option { - #[inline] - fn or_missing_param(self) -> Result { - self.ok_or(Error::MissingParam) - } - - #[inline] - fn or_unauthorized(self) -> Result { - self.ok_or(Error::Unauthorized) - } -} - pub struct Authorization<'a> { pub code: Cow<'a, str>, pub client: Client<'a>, - pub pkce_payload: Option>, + pub pkce_payload: Option>, pub scopes: Cow<'a, [Cow<'a, str>]>, } -pub struct PreAuthorization<'a, 'b> { +pub struct AuthInstruction<'a, 'b> { pub client: &'b Client<'a>, - pub pkce_payload: Option<&'b PkcePayload<'a>>, + pub pkce_payload: Option<&'b pkce::Payload<'a>>, pub scopes: &'b [&'b str], }