Skip to content

Commit

Permalink
use rfc-compliant granular errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Dec 19, 2024
1 parent 26023e8 commit fead698
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 63 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/komainu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ url.workspace = true
[dev-dependencies]
divan.workspace = true
headers.workspace = true
insta.workspace = true
rstest.workspace = true
serde_test.workspace = true

Expand Down
14 changes: 7 additions & 7 deletions lib/komainu/benches/pkce.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
use divan::black_box;
use komainu::flow::{PkceMethod, PkcePayload};
use komainu::flow::pkce;
use std::borrow::Cow;

#[global_allocator]
static GLOBAL: divan::AllocProfiler = divan::AllocProfiler::system();

#[divan::bench]
fn s256() -> komainu::Result<()> {
fn s256() -> Result<(), komainu::flow::FlowError> {
let verifier_base64 = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
let challenge_base64 = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

let payload = PkcePayload {
method: black_box(PkceMethod::S256),
let payload = pkce::Payload {
method: black_box(pkce::Method::S256),
challenge: black_box(Cow::Borrowed(challenge_base64)),
};

payload.verify(black_box(verifier_base64))
}

#[divan::bench]
fn none() -> komainu::Result<()> {
fn none() -> Result<(), komainu::flow::FlowError> {
let value = "arbitrary value";

let payload = PkcePayload {
method: black_box(PkceMethod::None),
let payload = pkce::Payload {
method: black_box(pkce::Method::None),
challenge: black_box(Cow::Borrowed(value)),
};

Expand Down
2 changes: 1 addition & 1 deletion lib/komainu/src/code_grant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ where

let client_id = query.get("client_id").or_invalid_request()?;
let response_type = query.get("response_type").or_invalid_request()?;
let scope = query.get("scope").map(Deref::deref).unwrap_or("");
let scope = query.get("scope").map_or("", Deref::deref);
let state = query.get("state").map(|state| &**state);

let client = self.client_extractor.extract(client_id, None).await?;
Expand Down
11 changes: 0 additions & 11 deletions lib/komainu/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use thiserror::Error;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug, Error)]
pub enum Error {
#[error("Malformed body")]
Expand All @@ -24,12 +22,3 @@ impl Error {
Self::Query(err.into())
}
}

macro_rules! ensure {
($cond:expr, $err:expr) => {{
if !{ $cond } {
return Err($err);
}
}};
}
pub(crate) use ensure;
7 changes: 2 additions & 5 deletions lib/komainu/src/extract.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
use crate::{
error::{Error, Result},
params::ParamStorage,
};
use crate::{error::Error, params::ParamStorage};
use bytes::Bytes;
use memchr::memchr;

static URL_ENCODED_CONTENT_TYPE: http::HeaderValue =
http::HeaderValue::from_static("application/x-www-form-urlencoded");

#[inline]
pub fn body<'a, T>(req: &'a http::Request<Bytes>) -> Result<T>
pub fn body<'a, T>(req: &'a http::Request<Bytes>) -> Result<T, Error>
where
T: serde::Deserialize<'a>,
{
Expand Down
31 changes: 16 additions & 15 deletions lib/komainu/src/flow/authorization.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
error::{Error, Result},
extract::ClientCredentials,
flow::TokenResponse,
flow::{FlowError, OptionExt, TokenResponse},
params::ParamStorage,
Authorization, ClientExtractor,
};
Expand All @@ -12,39 +11,40 @@ pub trait Issuer {
fn load_authorization(
&self,
auth_code: &str,
) -> impl Future<Output = Result<Option<Authorization<'_>>>> + Send;
) -> impl Future<Output = Result<Option<Authorization<'_>>, FlowError>> + Send;

fn issue_token(
&self,
authorization: &Authorization<'_>,
) -> impl Future<Output = Result<TokenResponse<'_>>> + Send;
) -> impl Future<Output = Result<TokenResponse<'_>, FlowError>> + Send;
}

#[instrument(skip_all)]
pub async fn perform<CE, I>(
req: http::Request<Bytes>,
client_extractor: CE,
token_issuer: I,
) -> Result<http::Response<Bytes>>
) -> Result<http::Response<Bytes>, FlowError>
where
CE: ClientExtractor,
I: Issuer,
{
let body: ParamStorage<&str, &str> = crate::extract::body(&req)?;
let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?;
let client_credentials =
ClientCredentials::extract(req.headers(), &body).or_invalid_request()?;

let (client_id, client_secret) = (
client_credentials.client_id(),
client_credentials.client_secret(),
);

let grant_type = body.get("grant_type").or_missing_param()?;
let code = body.get("code").or_missing_param()?;
let redirect_uri = body.get("redirect_uri").or_missing_param()?;
let grant_type = body.get("grant_type").or_invalid_request()?;
let code = body.get("code").or_invalid_request()?;
let redirect_uri = body.get("redirect_uri").or_invalid_request()?;

if *grant_type != "authorization_code" {
error!(?client_id, "grant_type is not authorization_code");
return Err(Error::Unauthorized);
return Err(FlowError::UnsupportedGrantType);
}

let client = client_extractor
Expand All @@ -53,19 +53,20 @@ where

if client.redirect_uri != *redirect_uri {
error!(?client_id, "redirect uri doesn't match");
return Err(Error::Unauthorized);
return Err(FlowError::InvalidClient);
}

let maybe_authorization = token_issuer.load_authorization(code).await?;
let authorization = maybe_authorization.or_unauthorized()?;
let Some(authorization) = token_issuer.load_authorization(code).await? else {
return Err(FlowError::InvalidGrant);
};

// This check is constant time :3
if client != authorization.client {
return Err(Error::Unauthorized);
return Err(FlowError::UnauthorizedClient);
}

if let Some(ref pkce) = authorization.pkce_payload {
let code_verifier = body.get("code_verifier").or_unauthorized()?;
let code_verifier = body.get("code_verifier").or_invalid_request()?;
pkce.verify(code_verifier)?;
}

Expand Down
43 changes: 38 additions & 5 deletions lib/komainu/src/flow/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
use serde::Serialize;
use std::borrow::Cow;
use strum::Display;
use thiserror::Error;

pub mod authorization;
pub mod pkce;
pub mod refresh;

trait OptionExt<T> {
fn or_invalid_request(self) -> Result<T, FlowError>;
}

impl<T> OptionExt<T> for Option<T> {
#[inline]
fn or_invalid_request(self) -> Result<T, FlowError> {
self.ok_or(FlowError::InvalidRequest)
}
}

#[derive(Debug, Display, Error, Serialize)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum FlowError {
InvalidRequest,
InvalidClient,
InvalidGrant,
UnauthorizedClient,
UnsupportedGrantType,
InvalidScope,
#[serde(skip)]
Other(#[from] crate::error::Error),
}

#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
Expand All @@ -13,9 +40,15 @@ pub enum TokenType {
}

#[derive(Serialize)]
pub struct TokenResponse<'a> {
pub access_token: Cow<'a, str>,
pub token_type: TokenType,
pub refresh_token: Cow<'a, str>,
pub expires_in: u64,
#[serde(untagged)]
pub enum TokenResponse<'a> {
Success {
access_token: Cow<'a, str>,
token_type: TokenType,
refresh_token: Cow<'a, str>,
expires_in: u64,
},
Error {
errorr: FlowError,

Check warning on line 52 in lib/komainu/src/flow/mod.rs

View workflow job for this annotation

GitHub Actions / Spell-check repository source

"errorr" should be "error".
},
}
12 changes: 6 additions & 6 deletions lib/komainu/src/flow/pkce.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::error::Error;
use crate::{error::Error, flow::FlowError};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::borrow::Cow;
Expand All @@ -22,7 +22,7 @@ pub struct Payload<'a> {

impl Payload<'_> {
#[inline]
fn verify_s256(&self, code_verifier: &str) -> Result<()> {
fn verify_s256(&self, code_verifier: &str) -> Result<(), FlowError> {
let decoded = base64_simd::URL_SAFE_NO_PAD
.decode_to_vec(self.challenge.as_bytes())
.inspect_err(|error| debug!(?error, "failed to decode pkce payload"))
Expand All @@ -32,22 +32,22 @@ impl Payload<'_> {
if decoded.ct_eq(hash.as_slice()).into() {
Ok(())
} else {
Err(Error::Unauthorized)
Err(FlowError::InvalidGrant)
}
}

#[inline]
fn verify_none(&self, code_verifier: &str) -> Result<()> {
fn verify_none(&self, code_verifier: &str) -> Result<(), FlowError> {
let challenge_bytes = self.challenge.as_bytes();
if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() {
Ok(())
} else {
Err(Error::Unauthorized)
Err(FlowError::InvalidGrant)
}
}

#[inline]
pub fn verify(&self, code_verifier: &str) -> Result<()> {
pub fn verify(&self, code_verifier: &str) -> Result<(), FlowError> {
match self.method {
Method::None => self.verify_none(code_verifier),
Method::S256 => self.verify_s256(code_verifier),
Expand Down
16 changes: 8 additions & 8 deletions lib/komainu/src/flow/refresh.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::TokenResponse;
use crate::{
error::{Error, Result},
extract::ClientCredentials,
flow::{FlowError, OptionExt, TokenResponse},
params::ParamStorage,
Client, ClientExtractor,
};
Expand All @@ -13,33 +12,34 @@ pub trait Issuer {
&self,
client: &Client<'_>,
refresh_token: &str,
) -> impl Future<Output = Result<TokenResponse<'_>>> + Send;
) -> impl Future<Output = Result<TokenResponse<'_>, FlowError>> + Send;
}

#[instrument(skip_all)]
pub async fn perform<CE, I>(
req: http::Request<Bytes>,
client_extractor: CE,
token_issuer: I,
) -> Result<http::Response<Bytes>>
) -> Result<http::Response<Bytes>, FlowError>
where
CE: ClientExtractor,
I: Issuer,
{
let body: ParamStorage<&str, &str> = crate::extract::body(&req)?;
let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?;
let client_credentials =
ClientCredentials::extract(req.headers(), &body).or_invalid_request()?;

let (client_id, client_secret) = (
client_credentials.client_id(),
client_credentials.client_secret(),
);

let grant_type = body.get("grant_type").or_missing_param()?;
let refresh_token = body.get("refresh_token").or_missing_param()?;
let grant_type = body.get("grant_type").or_invalid_request()?;
let refresh_token = body.get("refresh_token").or_invalid_request()?;

if *grant_type != "refresh_token" {
debug!(?client_id, "grant_type is not refresh_token");
return Err(Error::Unauthorized);
return Err(FlowError::UnsupportedGrantType);
}

let client = client_extractor
Expand Down
4 changes: 2 additions & 2 deletions lib/komainu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use self::flow::pkce;
use std::{borrow::Cow, future::Future};
use subtle::ConstantTimeEq;

pub use self::error::{Error, Result};
pub use self::error::Error;
pub use self::params::ParamStorage;

mod error;
Expand Down Expand Up @@ -54,5 +54,5 @@ pub trait ClientExtractor {
&self,
client_id: &str,
client_secret: Option<&str>,
) -> impl Future<Output = Result<Client<'_>>> + Send;
) -> impl Future<Output = Result<Client<'_>, Error>> + Send;
}
6 changes: 3 additions & 3 deletions lib/komainu/tests/pkce.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use komainu::flow::{PkceMethod, PkcePayload};
use komainu::flow::pkce;
use std::borrow::Cow;

#[test]
Expand All @@ -25,8 +25,8 @@ fn verify_rfc_payload_s256() {
challenge_base64
);

let payload = PkcePayload {
method: PkceMethod::S256,
let payload = pkce::Payload {
method: pkce::Method::S256,
challenge: Cow::Borrowed(challenge_base64),
};
payload.verify(verifier_base64).unwrap();
Expand Down

0 comments on commit fead698

Please sign in to comment.