diff --git a/lib/komainu/src/authorize.rs b/lib/komainu/src/authorize.rs index bd6cd1ea..214f6f65 100644 --- a/lib/komainu/src/authorize.rs +++ b/lib/komainu/src/authorize.rs @@ -1,9 +1,10 @@ use crate::{ error::{Error, Result}, + flow::{PkceMethod, PkcePayload}, params::ParamStorage, - Authorization, Client, ClientExtractor, OAuthError, OptionExt, + Authorization, Client, ClientExtractor, OAuthError, OptionExt, PreAuthorization, }; -use std::{borrow::Borrow, collections::HashSet, future::Future}; +use std::{borrow::Borrow, collections::HashSet, future::Future, str::FromStr}; pub trait Issuer { type UserId; @@ -11,8 +12,7 @@ pub trait Issuer { fn issue_code( &self, user_id: Self::UserId, - client_id: &str, - scopes: &[&str], + pre_authorization: PreAuthorization<'_>, ) -> impl Future>> + Send; } @@ -75,9 +75,22 @@ where return Err(Error::Unauthorized); } + let pkce_payload = if let Some(challenge) = query.get("code_challenge") { + let method = if let Some(method) = query.get("challenge_code_method") { + PkceMethod::from_str(*method).map_err(Error::query)? + } else { + PkceMethod::default() + }; + + Some(PkcePayload { method, challenge }) + } else { + None + }; + Ok(Authorizer { issuer: &self.issuer, client, + pkce_payload, query, state, }) @@ -87,6 +100,7 @@ where pub struct Authorizer<'a, I> { issuer: &'a I, client: Client<'a>, + pkce_payload: Option>, query: ParamStorage<&'a str, &'a str>, state: Option<&'a str>, } @@ -120,9 +134,15 @@ where #[inline] #[instrument(skip_all)] pub async fn accept(self, user_id: I::UserId, scopes: &[&str]) -> http::Response<()> { + let pre_authorization = PreAuthorization { + client: self.client, + scopes, + pkce_payload: self.pkce_payload, + }; + let code = self .issuer - .issue_code(user_id, self.client.client_id, scopes) + .issue_code(user_id, pre_authorization) .await .unwrap(); diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index 910b6e64..fc03092b 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -75,7 +75,10 @@ where return Err(Error::Unauthorized); } - // TODO: Verify PKCE challenge + if let Some(ref pkce) = authorization.pkce_payload { + let code_verifier = body.get("code_verifier").or_unauthorized()?; + pkce.verify(code_verifier)?; + } let token = token_issuer.issue_token(&authorization).await?; let body = sonic_rs::to_vec(&token).unwrap(); diff --git a/lib/komainu/src/flow/mod.rs b/lib/komainu/src/flow/mod.rs index 04316d79..b1b24b20 100644 --- a/lib/komainu/src/flow/mod.rs +++ b/lib/komainu/src/flow/mod.rs @@ -23,9 +23,10 @@ pub struct TokenResponse<'a> { pub expires_in: u64, } -#[derive(AsRefStr, Deserialize, EnumString, Serialize)] +#[derive(AsRefStr, Default, Deserialize, EnumString, Serialize)] #[strum(serialize_all = "snake_case")] pub enum PkceMethod { + #[default] None, #[strum(serialize = "S256")] S256, diff --git a/lib/komainu/src/lib.rs b/lib/komainu/src/lib.rs index 62aedf2b..23aa2a17 100644 --- a/lib/komainu/src/lib.rs +++ b/lib/komainu/src/lib.rs @@ -37,7 +37,14 @@ impl OptionExt for Option { pub struct Authorization<'a> { pub code: Cow<'a, str>, pub client: Client<'a>, - pub pkce_payload: PkcePayload<'a>, + pub pkce_payload: Option>, + pub scopes: Cow<'a, [Cow<'a, str>]>, +} + +pub struct PreAuthorization<'a> { + pub client: Client<'a>, + pub pkce_payload: Option>, + pub scopes: &'a [&'a str], } pub struct Client<'a> {