From 11f195839173e99b0a6861493c227b571240f7ad Mon Sep 17 00:00:00 2001 From: aumetra Date: Fri, 13 Dec 2024 22:28:10 +0100 Subject: [PATCH] add pkce verifier --- Cargo.lock | 3 ++ lib/komainu/Cargo.toml | 3 ++ lib/komainu/src/flow/authorization.rs | 9 ++++- lib/komainu/src/flow/mod.rs | 57 ++++++++++++++++++++++++++- lib/komainu/src/flow/refresh.rs | 2 +- 5 files changed, 71 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f4b17227..5c647fe32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3993,14 +3993,17 @@ dependencies = [ name = "komainu" version = "0.0.1-pre.6" dependencies = [ + "base64-simd", "bytes", "headers", "http", "serde", "serde_test", "serde_urlencoded", + "sha2", "sonic-rs", "strum", + "subtle", "thiserror 2.0.6", "tracing", "url", diff --git a/lib/komainu/Cargo.toml b/lib/komainu/Cargo.toml index 40311450e..c4b99084f 100644 --- a/lib/komainu/Cargo.toml +++ b/lib/komainu/Cargo.toml @@ -6,13 +6,16 @@ version.workspace = true license = "MIT OR Apache-2.0" [dependencies] +base64-simd.workspace = true bytes.workspace = true headers.workspace = true http.workspace = true serde.workspace = true serde_urlencoded.workspace = true +sha2.workspace = true sonic-rs.workspace = true strum.workspace = true +subtle.workspace = true thiserror.workspace = true tracing.workspace = true url.workspace = true diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index ce631b6d7..74cb2efc4 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -1,10 +1,17 @@ -use super::TokenResponse; +use super::{PkcePayload, TokenResponse}; use crate::{error::Result, params::ParamStorage, Client, ClientExtractor, Error, OptionExt}; use bytes::Bytes; use headers::HeaderMapExt; use std::future::Future; +pub struct Authorization<'a> { + pub client: Client<'a>, + pub pkce: PkcePayload<'a>, +} + pub trait Issuer { + fn load_authorization(&self, ); + fn issue_token( &self, client: &Client<'_>, diff --git a/lib/komainu/src/flow/mod.rs b/lib/komainu/src/flow/mod.rs index ee70f1f56..04316d798 100644 --- a/lib/komainu/src/flow/mod.rs +++ b/lib/komainu/src/flow/mod.rs @@ -1,5 +1,9 @@ -use serde::Serialize; +use crate::{Error, Result}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use std::borrow::Cow; +use strum::{AsRefStr, EnumString}; +use subtle::ConstantTimeEq; pub mod authorization; pub mod refresh; @@ -18,3 +22,54 @@ pub struct TokenResponse<'a> { pub refresh_token: Cow<'a, str>, pub expires_in: u64, } + +#[derive(AsRefStr, Deserialize, EnumString, Serialize)] +#[strum(serialize_all = "snake_case")] +pub enum PkceMethod { + None, + #[strum(serialize = "S256")] + S256, +} + +#[derive(Deserialize, Serialize)] +pub struct PkcePayload<'a> { + pub challenge: Cow<'a, str>, + pub method: PkceMethod, +} + +impl PkcePayload<'_> { + #[inline] + fn verify_s256(&self, code_verifier: &str) -> Result<()> { + let decoded = base64_simd::URL_SAFE + .decode_to_vec(code_verifier) + .inspect_err(|error| debug!(?error, "failed to decode pkce payload")) + .map_err(Error::body)?; + + let hash = Sha256::digest(code_verifier); + + if decoded.ct_eq(hash.as_slice()).into() { + Ok(()) + } else { + Err(Error::Unauthorized) + } + } + + #[inline] + fn verify_none(&self, code_verifier: &str) -> Result<()> { + let challenge_bytes = self.challenge.as_bytes(); + + if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() { + Ok(()) + } else { + Err(Error::Unauthorized) + } + } + + #[inline] + pub fn verify(&self, code_verifier: &str) -> Result<()> { + match self.method { + PkceMethod::None => self.verify_none(code_verifier), + PkceMethod::S256 => self.verify_s256(code_verifier), + } + } +} diff --git a/lib/komainu/src/flow/refresh.rs b/lib/komainu/src/flow/refresh.rs index 2a1a47a03..13f4d3352 100644 --- a/lib/komainu/src/flow/refresh.rs +++ b/lib/komainu/src/flow/refresh.rs @@ -1,8 +1,8 @@ use super::TokenResponse; use crate::{ error::{Error, Result}, - Client, ClientExtractor, OptionExt, params::ParamStorage, + Client, ClientExtractor, OptionExt, }; use bytes::Bytes; use headers::HeaderMapExt;