Skip to content

Commit

Permalink
move around, change errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Dec 17, 2024
1 parent d2198f6 commit b48719c
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 137 deletions.
87 changes: 56 additions & 31 deletions lib/komainu/src/authorize.rs → lib/komainu/src/code_grant.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@
use crate::{
error::{Error, OAuthError, Result},
flow::{PkceMethod, PkcePayload},
params::ParamStorage,
Client, ClientExtractor, OptionExt, PreAuthorization,
error::Error, flow::pkce, params::ParamStorage, AuthInstruction, Client, ClientExtractor,
};
use std::{
borrow::{Borrow, Cow},
collections::HashSet,
future::Future,
str::FromStr,
};
use strum::{AsRefStr, Display};
use thiserror::Error;

trait OptionExt<T> {
fn or_invalid_request(self) -> Result<T, GrantError>;
}

impl<T> OptionExt<T> for Option<T> {
#[inline]
fn or_invalid_request(self) -> Result<T, GrantError> {
self.ok_or(GrantError::InvalidRequest)
}
}

#[derive(AsRefStr, Debug, Display, Error)]
#[strum(serialize_all = "snake_case")]
pub enum GrantError {
InvalidRequest,
UnauthorizedClient,
AccessDenied,
UnsupportedResponseType,
InvalidScope,
ServerError,
TemporarilyUnavailable,
Other(#[from] Error),
}

pub trait Issuer {
type UserId;

fn issue_code(
&self,
user_id: Self::UserId,
pre_authorization: PreAuthorization<'_, '_>,
) -> impl Future<Output = Result<String>> + Send;
pre_authorization: AuthInstruction<'_, '_>,
) -> impl Future<Output = Result<String, GrantError>> + Send;
}

pub struct AuthorizerExtractor<I, CE> {
Expand All @@ -38,26 +61,29 @@ where
}

#[instrument(skip_all)]
pub async fn extract<'a>(&'a self, req: &'a http::Request<()>) -> Result<Authorizer<'a, I>> {
pub async fn extract<'a>(
&'a self,
req: &'a http::Request<()>,
) -> Result<Authorizer<'a, I>, GrantError> {
let query: ParamStorage<&str, &str> =
serde_urlencoded::from_str(req.uri().query().or_missing_param()?)
serde_urlencoded::from_str(req.uri().query().or_invalid_request()?)
.map_err(Error::query)?;

let client_id = query.get("client_id").or_missing_param()?;
let response_type = query.get("response_type").or_missing_param()?;
let client_id = query.get("client_id").or_invalid_request()?;
let response_type = query.get("response_type").or_invalid_request()?;
if *response_type != "code" {
debug!(?client_id, "response_type not set to \"code\"");
return Err(Error::Unauthorized);
return Err(GrantError::AccessDenied);
}

let scope = query.get("scope").or_missing_param()?;
let redirect_uri = query.get("redirect_uri").or_missing_param()?;
let scope = query.get("scope").or_invalid_request()?;
let redirect_uri = query.get("redirect_uri").or_invalid_request()?;
let state = query.get("state").map(|state| &**state);

let client = self.client_extractor.extract(client_id, None).await?;
if client.redirect_uri != *redirect_uri {
debug!(?client_id, "redirect uri doesn't match");
return Err(Error::Unauthorized);
return Err(GrantError::AccessDenied);
}

let request_scopes = scope.split_whitespace().collect::<HashSet<_>>();
Expand All @@ -69,17 +95,17 @@ where

if !request_scopes.is_subset(&client_scopes) {
debug!(?client_id, "scopes aren't a subset");
return Err(Error::Unauthorized);
return Err(GrantError::AccessDenied);
}

let pkce_payload = if let Some(challenge) = query.get("code_challenge") {
let method = if let Some(method) = query.get("challenge_code_method") {
PkceMethod::from_str(method).map_err(Error::query)?
pkce::Method::from_str(method).map_err(Error::query)?
} else {
PkceMethod::default()
pkce::Method::default()
};

Some(PkcePayload {
Some(pkce::Payload {
method,
challenge: Cow::Borrowed(challenge),
})
Expand Down Expand Up @@ -109,7 +135,7 @@ macro_rules! return_err {
pub struct Authorizer<'a, I> {
issuer: &'a I,
client: Client<'a>,
pkce_payload: Option<PkcePayload<'a>>,
pkce_payload: Option<pkce::Payload<'a>>,
query: ParamStorage<&'a str, &'a str>,
state: Option<&'a str>,
}
Expand Down Expand Up @@ -152,28 +178,31 @@ where
})
}

#[inline]
fn build_error_response(&self, error: &GrantError) -> http::Response<()> {
let mut uri = return_err!(self.redirect_uri());
uri.query_pairs_mut().append_pair("error", error.as_ref());
Self::build_response(uri)
}

#[inline]
#[instrument(skip_all)]
pub async fn accept(self, user_id: I::UserId, scopes: &[&str]) -> http::Response<()> {
let pre_authorization = PreAuthorization {
let pre_authorization = AuthInstruction {
client: &self.client,
scopes,
pkce_payload: self.pkce_payload.as_ref(),
};

let mut url = return_err!(self.redirect_uri());

let code = match self.issuer.issue_code(user_id, pre_authorization).await {
Ok(code) => code,
Err(error) => {
debug!(?error, "failed to issue code");
url.query_pairs_mut()
.append_pair("error", OAuthError::TemporarilyUnavailable.as_ref());

return Self::build_response(url);
return self.build_error_response(&GrantError::TemporarilyUnavailable);
}
};

let mut url = return_err!(self.redirect_uri());
url.query_pairs_mut().append_pair("code", &code);

if let Some(state) = self.state {
Expand All @@ -187,10 +216,6 @@ where
#[must_use]
#[instrument(skip_all)]
pub fn deny(self) -> http::Response<()> {
let mut url = return_err!(self.redirect_uri());
url.query_pairs_mut()
.append_pair("error", OAuthError::AccessDenied.as_ref());

Self::build_response(url)
self.build_error_response(&GrantError::AccessDenied)
}
}
26 changes: 0 additions & 26 deletions lib/komainu/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use serde::Serialize;
use strum::AsRefStr;
use thiserror::Error;

type BoxError = Box<dyn std::error::Error + Send + Sync>;
Expand All @@ -11,14 +9,8 @@ pub enum Error {
#[error("Malformed body")]
Body(#[source] BoxError),

#[error("Missing parameter")]
MissingParam,

#[error("Malformed query")]
Query(#[source] BoxError),

#[error("Request is unauthorized")]
Unauthorized,
}

impl Error {
Expand All @@ -32,21 +24,3 @@ impl Error {
Self::Query(err.into())
}
}

#[derive(AsRefStr, Serialize)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum OAuthError {
InvalidRequest,
UnauthorizedClient,
AccessDenied,
UnsupportedResponseType,
InvalidScope,
ServerError,
TemporarilyUnavailable,
}

#[derive(Serialize)]
pub struct OAuthErrorResponse {
pub error: OAuthError,
}
4 changes: 2 additions & 2 deletions lib/komainu/src/flow/authorization.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::TokenResponse;
use crate::{
error::{Error, Result},
extract::ClientCredentials,
flow::TokenResponse,
params::ParamStorage,
Authorization, ClientExtractor, OptionExt,
Authorization, ClientExtractor,
};
use bytes::Bytes;
use std::future::Future;
Expand Down
57 changes: 2 additions & 55 deletions lib/komainu/src/flow/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use serde::Serialize;
use std::borrow::Cow;
use strum::{AsRefStr, EnumString};
use subtle::ConstantTimeEq;

pub mod authorization;
pub mod pkce;
pub mod refresh;

#[derive(Serialize)]
Expand All @@ -22,53 +19,3 @@ pub struct TokenResponse<'a> {
pub refresh_token: Cow<'a, str>,
pub expires_in: u64,
}

#[derive(AsRefStr, Default, Deserialize, EnumString, Serialize)]
#[strum(serialize_all = "snake_case")]
pub enum PkceMethod {
#[default]
None,
#[strum(serialize = "S256")]
S256,
}

#[derive(Deserialize, Serialize)]
pub struct PkcePayload<'a> {
pub challenge: Cow<'a, str>,
pub method: PkceMethod,
}

impl PkcePayload<'_> {
#[inline]
fn verify_s256(&self, code_verifier: &str) -> Result<()> {
let decoded = base64_simd::URL_SAFE_NO_PAD
.decode_to_vec(self.challenge.as_bytes())
.inspect_err(|error| debug!(?error, "failed to decode pkce payload"))
.map_err(Error::body)?;

let hash = Sha256::digest(code_verifier);
if decoded.ct_eq(hash.as_slice()).into() {
Ok(())
} else {
Err(Error::Unauthorized)
}
}

#[inline]
fn verify_none(&self, code_verifier: &str) -> Result<()> {
let challenge_bytes = self.challenge.as_bytes();
if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() {
Ok(())
} else {
Err(Error::Unauthorized)
}
}

#[inline]
pub fn verify(&self, code_verifier: &str) -> Result<()> {
match self.method {
PkceMethod::None => self.verify_none(code_verifier),
PkceMethod::S256 => self.verify_s256(code_verifier),
}
}
}
56 changes: 56 additions & 0 deletions lib/komainu/src/flow/pkce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use crate::error::Error;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::borrow::Cow;
use strum::{AsRefStr, EnumString};
use subtle::ConstantTimeEq;

#[derive(AsRefStr, Default, Deserialize, EnumString, Serialize)]
#[strum(serialize_all = "snake_case")]
pub enum Method {
#[default]
None,
#[strum(serialize = "S256")]
S256,
}

#[derive(Deserialize, Serialize)]
pub struct Payload<'a> {
pub challenge: Cow<'a, str>,
pub method: Method,
}

impl Payload<'_> {
#[inline]
fn verify_s256(&self, code_verifier: &str) -> Result<()> {
let decoded = base64_simd::URL_SAFE_NO_PAD
.decode_to_vec(self.challenge.as_bytes())
.inspect_err(|error| debug!(?error, "failed to decode pkce payload"))
.map_err(Error::body)?;

let hash = Sha256::digest(code_verifier);
if decoded.ct_eq(hash.as_slice()).into() {
Ok(())
} else {
Err(Error::Unauthorized)
}
}

#[inline]
fn verify_none(&self, code_verifier: &str) -> Result<()> {
let challenge_bytes = self.challenge.as_bytes();
if challenge_bytes.ct_eq(code_verifier.as_bytes()).into() {
Ok(())
} else {
Err(Error::Unauthorized)
}
}

#[inline]
pub fn verify(&self, code_verifier: &str) -> Result<()> {
match self.method {
Method::None => self.verify_none(code_verifier),
Method::S256 => self.verify_s256(code_verifier),
}
}
}
2 changes: 1 addition & 1 deletion lib/komainu/src/flow/refresh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
error::{Error, Result},
extract::ClientCredentials,
params::ParamStorage,
Client, ClientExtractor, OptionExt,
Client, ClientExtractor,
};
use bytes::Bytes;
use std::future::Future;
Expand Down
Loading

0 comments on commit b48719c

Please sign in to comment.