Skip to content

Commit

Permalink
rename, revert stuff, fix error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Dec 17, 2024
1 parent c37cd4b commit d2198f6
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 94 deletions.
2 changes: 1 addition & 1 deletion lib/komainu/benches/basic_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod headers {
#[divan::bench_group]
mod ours {
use divan::{black_box, black_box_drop};
use komainu::extractor::BasicAuth;
use komainu::extract::BasicAuth;

#[divan::bench]
fn rfc_value(b: divan::Bencher<'_, '_>) {
Expand Down
41 changes: 34 additions & 7 deletions lib/komainu/src/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ where
}
}

macro_rules! return_err {
($result:expr) => {{
match { $result } {
Ok(val) => val,
Err(err) => return err,
}
}};
}

pub struct Authorizer<'a, I> {
issuer: &'a I,
client: Client<'a>,
Expand Down Expand Up @@ -131,6 +140,18 @@ where
.unwrap()
}

#[inline]
fn redirect_uri(&self) -> Result<url::Url, http::Response<()>> {
url::Url::parse(&self.client.redirect_uri).map_err(|error| {
error!(?error, redirect_uri = ?self.client.redirect_uri, "invalid redirect uri");

http::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.body(())
.unwrap()
})
}

#[inline]
#[instrument(skip_all)]
pub async fn accept(self, user_id: I::UserId, scopes: &[&str]) -> http::Response<()> {
Expand All @@ -140,13 +161,19 @@ where
pkce_payload: self.pkce_payload.as_ref(),
};

let code = self
.issuer
.issue_code(user_id, pre_authorization)
.await
.unwrap();
let mut url = return_err!(self.redirect_uri());

let code = match self.issuer.issue_code(user_id, pre_authorization).await {
Ok(code) => code,
Err(error) => {
debug!(?error, "failed to issue code");
url.query_pairs_mut()
.append_pair("error", OAuthError::TemporarilyUnavailable.as_ref());

return Self::build_response(url);
}
};

let mut url = url::Url::parse(&self.client.redirect_uri).unwrap();
url.query_pairs_mut().append_pair("code", &code);

if let Some(state) = self.state {
Expand All @@ -160,7 +187,7 @@ where
#[must_use]
#[instrument(skip_all)]
pub fn deny(self) -> http::Response<()> {
let mut url = url::Url::parse(&self.client.redirect_uri).unwrap();
let mut url = return_err!(self.redirect_uri());
url.query_pairs_mut()
.append_pair("error", OAuthError::AccessDenied.as_ref());

Expand Down
40 changes: 0 additions & 40 deletions lib/komainu/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,6 @@ impl Error {
}
}

impl From<Error> for OAuthError {
#[track_caller]
fn from(value: Error) -> Self {
debug!(error = ?value);

match value {
Error::Body(..) | Error::MissingParam | Error::Query(..) => Self::InvalidRequest,
Error::Unauthorized => Self::AccessDenied,
}
}
}

#[derive(AsRefStr, Serialize)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
Expand All @@ -62,31 +50,3 @@ pub enum OAuthError {
pub struct OAuthErrorResponse {
pub error: OAuthError,
}

macro_rules! fallible {
($op:expr) => {{
match { $op } {
Ok(val) => val,
Err(error) => {
debug!(?error);
$crate::error::yield_error!(error);
}
}
}};
}

macro_rules! yield_error {
(@ser $error:expr) => {{
return ::http::Response::builder()
.status(::http::StatusCode::BAD_REQUEST)
.body(sonic_rs::to_vec(&$error).unwrap().into())
.unwrap();
}};
($error:expr) => {{
$crate::error::yield_error!(@ser $crate::error::OAuthErrorResponse {
error: $error.into(),
});
}};
}

pub(crate) use {fallible, yield_error};
File renamed without changes.
50 changes: 24 additions & 26 deletions lib/komainu/src/flow/authorization.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::TokenResponse;
use crate::{
error::{fallible, yield_error, Result},
extractor::ClientCredentials,
error::{Error, Result},
extract::ClientCredentials,
params::ParamStorage,
Authorization, ClientExtractor, Error, OptionExt,
Authorization, ClientExtractor, OptionExt,
};
use bytes::Bytes;
use std::future::Future;
Expand All @@ -25,61 +25,59 @@ pub async fn perform<CE, I>(
req: http::Request<Bytes>,
client_extractor: CE,
token_issuer: I,
) -> http::Response<Bytes>
) -> Result<http::Response<Bytes>>
where
CE: ClientExtractor,
I: Issuer,
{
let body: ParamStorage<&str, &str> = fallible!(crate::extractor::body(&req));

let client_credentials =
fallible!(ClientCredentials::extract(req.headers(), &body).or_unauthorized());
let body: ParamStorage<&str, &str> = crate::extract::body(&req)?;
let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?;

let (client_id, client_secret) = (
client_credentials.client_id(),
client_credentials.client_secret(),
);

let grant_type = fallible!(body.get("grant_type").or_missing_param());
let code = fallible!(body.get("code").or_missing_param());
let redirect_uri = fallible!(body.get("redirect_uri").or_missing_param());
let grant_type = body.get("grant_type").or_missing_param()?;
let code = body.get("code").or_missing_param()?;
let redirect_uri = body.get("redirect_uri").or_missing_param()?;

if *grant_type != "authorization_code" {
error!(?client_id, "grant_type is not authorization_code");
yield_error!(Error::Unauthorized);
return Err(Error::Unauthorized);
}

let client = fallible!(
client_extractor
.extract(client_id, Some(client_secret))
.await
);
let client = client_extractor
.extract(client_id, Some(client_secret))
.await?;

if client.redirect_uri != *redirect_uri {
error!(?client_id, "redirect uri doesn't match");
yield_error!(Error::Unauthorized);
return Err(Error::Unauthorized);
}

let maybe_authorization = fallible!(token_issuer.load_authorization(code).await);
let authorization = fallible!(maybe_authorization.or_unauthorized());
let maybe_authorization = token_issuer.load_authorization(code).await?;
let authorization = maybe_authorization.or_unauthorized()?;

// This check is constant time :3
if client != authorization.client {
yield_error!(Error::Unauthorized);
return Err(Error::Unauthorized);
}

if let Some(ref pkce) = authorization.pkce_payload {
let code_verifier = fallible!(body.get("code_verifier").or_unauthorized());
fallible!(pkce.verify(code_verifier));
let code_verifier = body.get("code_verifier").or_unauthorized()?;
pkce.verify(code_verifier)?;
}

let token = fallible!(token_issuer.issue_token(&authorization).await);
let token = token_issuer.issue_token(&authorization).await?;
let body = sonic_rs::to_vec(&token).unwrap();

debug!("token successfully issued. building response");

http::Response::builder()
let response = http::Response::builder()
.status(http::StatusCode::OK)
.body(body.into())
.unwrap()
.unwrap();

Ok(response)
}
34 changes: 16 additions & 18 deletions lib/komainu/src/flow/refresh.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::TokenResponse;
use crate::{
error::{fallible, yield_error, Error, Result},
extractor::ClientCredentials,
error::{Error, Result},
extract::ClientCredentials,
params::ParamStorage,
Client, ClientExtractor, OptionExt,
};
Expand All @@ -21,42 +21,40 @@ pub async fn perform<CE, I>(
req: http::Request<Bytes>,
client_extractor: CE,
token_issuer: I,
) -> http::Response<Bytes>
) -> Result<http::Response<Bytes>>
where
CE: ClientExtractor,
I: Issuer,
{
let body: ParamStorage<&str, &str> = fallible!(crate::extractor::body(&req));

let client_credentials =
fallible!(ClientCredentials::extract(req.headers(), &body).or_unauthorized());
let body: ParamStorage<&str, &str> = crate::extract::body(&req)?;
let client_credentials = ClientCredentials::extract(req.headers(), &body).or_unauthorized()?;

let (client_id, client_secret) = (
client_credentials.client_id(),
client_credentials.client_secret(),
);

let grant_type = fallible!(body.get("grant_type").or_missing_param());
let refresh_token = fallible!(body.get("refresh_token").or_missing_param());
let grant_type = body.get("grant_type").or_missing_param()?;
let refresh_token = body.get("refresh_token").or_missing_param()?;

if *grant_type != "refresh_token" {
debug!(?client_id, "grant_type is not refresh_token");
yield_error!(Error::Unauthorized);
return Err(Error::Unauthorized);
}

let client = fallible!(
client_extractor
.extract(client_id, Some(client_secret))
.await
);
let client = client_extractor
.extract(client_id, Some(client_secret))
.await?;

let token = fallible!(token_issuer.issue_token(&client, refresh_token).await);
let token = token_issuer.issue_token(&client, refresh_token).await?;
let body = sonic_rs::to_vec(&token).unwrap();

debug!("token successfully issued. building response");

http::Response::builder()
let response = http::Response::builder()
.status(http::StatusCode::OK)
.body(body.into())
.unwrap()
.unwrap();

Ok(response)
}
2 changes: 1 addition & 1 deletion lib/komainu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use self::params::ParamStorage;
mod error;

pub mod authorize;
pub mod extractor;
pub mod extract;
pub mod flow;
pub mod params;

Expand Down
2 changes: 1 addition & 1 deletion lib/komainu/tests/basic_auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use komainu::extractor::BasicAuth;
use komainu::extract::BasicAuth;
use rstest::rstest;

#[test]
Expand Down

0 comments on commit d2198f6

Please sign in to comment.