From 51b6f918b792fe1708b9613cc7d46efc553b55e6 Mon Sep 17 00:00:00 2001 From: Aumetra Weisman Date: Mon, 8 Apr 2024 00:51:04 +0200 Subject: [PATCH] use try-block polyfills everywhere --- Cargo.lock | 5 +- crates/kitsune-db/Cargo.toml | 3 +- crates/kitsune-db/src/error.rs | 23 --- crates/kitsune-db/src/lib.rs | 6 +- crates/kitsune-db/src/pool.rs | 15 +- kitsune-cli/Cargo.toml | 1 + kitsune-cli/src/main.rs | 3 +- kitsune-job-runner/src/main.rs | 5 +- kitsune/Cargo.toml | 1 + kitsune/src/http/extractor/auth.rs | 23 ++- kitsune/src/main.rs | 1 + kitsune/src/oauth2/authorizer.rs | 94 ++++----- kitsune/src/oauth2/issuer.rs | 302 +++++++++++++++-------------- kitsune/src/oauth2/registrar.rs | 59 +++--- kitsune/src/oauth2/solicitor.rs | 33 ++-- 15 files changed, 292 insertions(+), 282 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28ae90280..359fcaee7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3620,6 +3620,7 @@ dependencies = [ "tower-stop-using-brave", "tower-x-clacks-overhead", "tracing", + "trials", "typed-builder", "url", "utoipa", @@ -3715,6 +3716,7 @@ dependencies = [ "envy", "kitsune-config", "kitsune-db", + "kitsune-error", "serde", "speedy-uuid", "tokio", @@ -3760,6 +3762,7 @@ dependencies = [ "futures-util", "iso8601-timestamp", "kitsune-config", + "kitsune-error", "kitsune-language", "kitsune-test", "kitsune-type", @@ -3770,12 +3773,12 @@ dependencies = [ "serde", "simd-json", "speedy-uuid", - "thiserror", "tokio", "tokio-postgres", "tokio-postgres-rustls", "tracing", "tracing-log", + "trials", "typed-builder", ] diff --git a/crates/kitsune-db/Cargo.toml b/crates/kitsune-db/Cargo.toml index 57de69478..eb4534549 100644 --- a/crates/kitsune-db/Cargo.toml +++ b/crates/kitsune-db/Cargo.toml @@ -22,6 +22,7 @@ futures-util = { version = "0.3.30", default-features = false, features = [ ] } iso8601-timestamp = { version = "0.2.17", features = ["diesel-pg"] } kitsune-config = { path = "../kitsune-config" } +kitsune-error = { path = "../kitsune-error" } kitsune-language = { path = "../kitsune-language" } kitsune-type = { path = "../kitsune-type" } num-derive = "0.4.2" @@ -36,12 +37,12 @@ rustls-native-certs = "0.7.0" serde = { version = "1.0.197", features = ["derive"] } simd-json = "0.13.9" speedy-uuid = { path = "../../lib/speedy-uuid", features = ["diesel"] } -thiserror = "1.0.58" tokio = { version = "1.37.0", features = ["rt"] } tokio-postgres = "0.7.10" tokio-postgres-rustls = "0.12.0" tracing = "0.1.40" tracing-log = "0.2.0" +trials = { path = "../../lib/trials" } typed-builder = "0.18.1" [dev-dependencies] diff --git a/crates/kitsune-db/src/error.rs b/crates/kitsune-db/src/error.rs index 12daf1669..a7f0634dc 100644 --- a/crates/kitsune-db/src/error.rs +++ b/crates/kitsune-db/src/error.rs @@ -1,10 +1,5 @@ use core::fmt; -use diesel_async::pooled_connection::bb8; use std::error::Error as StdError; -use thiserror::Error; - -pub type BoxError = Box; -pub type Result = std::result::Result; #[derive(Debug)] pub struct EnumConversionError(pub i32); @@ -35,21 +30,3 @@ impl fmt::Display for IsoCodeConversionError { } impl StdError for IsoCodeConversionError {} - -#[derive(Debug, Error)] -pub enum Error { - #[error(transparent)] - Blocking(#[from] blowocking::Error), - - #[error(transparent)] - Diesel(#[from] diesel::result::Error), - - #[error(transparent)] - DieselConnection(#[from] diesel::result::ConnectionError), - - #[error(transparent)] - Migration(BoxError), - - #[error(transparent)] - Pool(#[from] bb8::RunError), -} diff --git a/crates/kitsune-db/src/lib.rs b/crates/kitsune-db/src/lib.rs index a9e851a86..396db7123 100644 --- a/crates/kitsune-db/src/lib.rs +++ b/crates/kitsune-db/src/lib.rs @@ -9,13 +9,13 @@ use diesel_async::{ }; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use kitsune_config::database::Configuration as DatabaseConfig; +use kitsune_error::{Error, Result}; use tracing_log::LogTracer; pub type PgPool = Pool; -pub use crate::error::{Error, Result}; #[doc(hidden)] -pub use diesel_async; +pub use {diesel_async, kitsune_error, trials}; mod error; mod pool; @@ -45,7 +45,7 @@ pub async fn connect(config: &DatabaseConfig) -> Result { migration_conn .run_pending_migrations(MIGRATIONS) - .map_err(Error::Migration)?; + .map_err(Error::msg)?; Ok::<_, Error>(()) } diff --git a/crates/kitsune-db/src/pool.rs b/crates/kitsune-db/src/pool.rs index 2201c9758..bd8662843 100644 --- a/crates/kitsune-db/src/pool.rs +++ b/crates/kitsune-db/src/pool.rs @@ -7,20 +7,13 @@ macro_rules! with_connection { }}; } -#[macro_export] -macro_rules! catch_error { - ($($tt:tt)*) => {{ - let result: ::std::result::Result<_, ::diesel_async::pooled_connection::bb8::RunError> = async { - Ok({ $($tt)* }) - }.await; - result - }}; -} - #[macro_export] macro_rules! with_connection_panicky { ($pool:expr, $($other:tt)*) => {{ - $crate::catch_error!($crate::with_connection!($pool, $($other)*)).unwrap() + let result: $crate::kitsune_error::Result<_> = $crate::trials::attempt! { async + $crate::with_connection!($pool, $($other)*) + }; + result.unwrap() }}; } diff --git a/kitsune-cli/Cargo.toml b/kitsune-cli/Cargo.toml index d0408b62c..9a3cbd77f 100644 --- a/kitsune-cli/Cargo.toml +++ b/kitsune-cli/Cargo.toml @@ -21,6 +21,7 @@ dotenvy = "0.15.7" envy = "0.4.2" kitsune-config = { path = "../crates/kitsune-config" } kitsune-db = { path = "../crates/kitsune-db" } +kitsune-error = { path = "../crates/kitsune-error" } serde = { version = "1.0.197", features = ["derive"] } speedy-uuid = { path = "../lib/speedy-uuid" } tokio = { version = "1.37.0", features = ["full"] } diff --git a/kitsune-cli/src/main.rs b/kitsune-cli/src/main.rs index add367524..bef612589 100644 --- a/kitsune-cli/src/main.rs +++ b/kitsune-cli/src/main.rs @@ -33,7 +33,8 @@ async fn main() -> Result<()> { max_connections: 1, use_tls: config.database_use_tls, }) - .await?; + .await + .map_err(kitsune_error::Error::into_error)?; let cmd = App::parse(); diff --git a/kitsune-job-runner/src/main.rs b/kitsune-job-runner/src/main.rs index 877581f22..7fee9eb73 100644 --- a/kitsune-job-runner/src/main.rs +++ b/kitsune-job-runner/src/main.rs @@ -30,7 +30,10 @@ async fn main() -> eyre::Result<()> { kitsune_observability::initialise(env!("CARGO_PKG_NAME"), &config)?; - let db_pool = kitsune_db::connect(&config.database).await?; + let db_pool = kitsune_db::connect(&config.database) + .await + .map_err(kitsune_error::Error::into_error)?; + let job_queue = kitsune_job_runner::prepare_job_queue(db_pool.clone(), &config.job_queue).await?; diff --git a/kitsune/Cargo.toml b/kitsune/Cargo.toml index 9b48313dc..748dbb0c2 100644 --- a/kitsune/Cargo.toml +++ b/kitsune/Cargo.toml @@ -103,6 +103,7 @@ tower-http = { version = "0.5.2", features = [ ] } tower-http-digest = { path = "../lib/tower-http-digest" } tracing = "0.1.40" +trials = { path = "../lib/trials" } typed-builder = "0.18.1" url = "2.5.0" utoipa = { version = "4.2.0", features = ["axum_extras", "uuid"] } diff --git a/kitsune/src/http/extractor/auth.rs b/kitsune/src/http/extractor/auth.rs index b6b37e0ea..acd586ceb 100644 --- a/kitsune/src/http/extractor/auth.rs +++ b/kitsune/src/http/extractor/auth.rs @@ -11,13 +11,13 @@ use diesel_async::RunQueryDsl; use headers::{authorization::Bearer, Authorization}; use http::request::Parts; use kitsune_db::{ - catch_error, model::{account::Account, user::User}, schema::{accounts, oauth2_access_tokens, users}, with_connection, }; -use kitsune_error::Error; +use kitsune_error::{Error, Result}; use time::OffsetDateTime; +use trials::attempt; /// Mastodon-specific auth extractor alias /// @@ -64,14 +64,17 @@ impl FromRequestParts .filter(oauth2_access_tokens::expires_at.gt(OffsetDateTime::now_utc())); } - let (user, account) = catch_error!(with_connection!(state.db_pool, |db_conn| { - user_account_query - .select(<(User, Account)>::as_select()) - .get_result(db_conn) - .await - .map_err(Error::from) - })) - .map_err(Error::from)??; + let result: Result<(User, Account)> = attempt! { async + with_connection!(state.db_pool, |db_conn| { + user_account_query + .select(<(User, Account)>::as_select()) + .get_result(db_conn) + .await + .map_err(Error::from) + })? + }; + + let (user, account) = result?; Ok(Self(UserData { account, user })) } diff --git a/kitsune/src/main.rs b/kitsune/src/main.rs index 79f17be8a..107ca69d3 100644 --- a/kitsune/src/main.rs +++ b/kitsune/src/main.rs @@ -27,6 +27,7 @@ async fn boot() -> eyre::Result<()> { let conn = kitsune_db::connect(&config.database) .await + .map_err(kitsune_error::Error::into_error) .wrap_err("Failed to connect to and migrate the database")?; let job_queue = kitsune_job_runner::prepare_job_queue(conn.clone(), &config.job_queue) diff --git a/kitsune/src/oauth2/authorizer.rs b/kitsune/src/oauth2/authorizer.rs index 9c4184413..dbc541a57 100644 --- a/kitsune/src/oauth2/authorizer.rs +++ b/kitsune/src/oauth2/authorizer.rs @@ -3,14 +3,15 @@ use async_trait::async_trait; use diesel::{OptionalExtension, QueryDsl}; use diesel_async::RunQueryDsl; use kitsune_db::{ - catch_error, model::oauth2, schema::{oauth2_applications, oauth2_authorization_codes}, with_connection, PgPool, }; +use kitsune_error::Result; use kitsune_util::generate_secret; use oxide_auth::primitives::grant::{Extensions, Grant}; use oxide_auth_async::primitives::Authorizer; +use trials::attempt; #[derive(Clone)] pub struct OAuthAuthorizer { @@ -19,56 +20,61 @@ pub struct OAuthAuthorizer { #[async_trait] impl Authorizer for OAuthAuthorizer { + #[instrument(skip_all)] async fn authorize(&mut self, grant: Grant) -> Result { - let application_id = grant.client_id.parse().map_err(|_| ())?; - let user_id = grant.owner_id.parse().map_err(|_| ())?; - let scopes = grant.scope.to_string(); - let secret = generate_secret(); - let expires_at = chrono_to_timestamp(grant.until); + let result: Result<_> = attempt! { async + let application_id = grant.client_id.parse()?; + let user_id = grant.owner_id.parse()?; + let scopes = grant.scope.to_string(); + let secret = generate_secret(); + let expires_at = chrono_to_timestamp(grant.until); - catch_error!(with_connection!(self.db_pool, |db_conn| { - diesel::insert_into(oauth2_authorization_codes::table) - .values(oauth2::NewAuthorizationCode { - code: secret.as_str(), - application_id, - user_id, - scopes: scopes.as_str(), - expires_at, - }) - .returning(oauth2_authorization_codes::code) - .get_result(db_conn) - .await - })) - .map_err(|_| ())? - .map_err(|_| ()) + with_connection!(self.db_pool, |db_conn| { + diesel::insert_into(oauth2_authorization_codes::table) + .values(oauth2::NewAuthorizationCode { + code: secret.as_str(), + application_id, + user_id, + scopes: scopes.as_str(), + expires_at, + }) + .returning(oauth2_authorization_codes::code) + .get_result(db_conn) + .await + })? + }; + + result.map_err(|error| debug!(?error, "authorize failed")) } + #[instrument(skip_all)] async fn extract(&mut self, authorization_code: &str) -> Result, ()> { - let oauth_data = catch_error!(with_connection!(self.db_pool, |db_conn| { - oauth2_authorization_codes::table - .find(authorization_code) - .inner_join(oauth2_applications::table) - .first::<(oauth2::AuthorizationCode, oauth2::Application)>(db_conn) - .await - .optional() - })) - .map_err(|_| ())? - .map_err(|_| ())?; + let result: Result<_> = attempt! { async + let oauth_data = with_connection!(self.db_pool, |db_conn| { + oauth2_authorization_codes::table + .find(authorization_code) + .inner_join(oauth2_applications::table) + .first::<(oauth2::AuthorizationCode, oauth2::Application)>(db_conn) + .await + .optional() + })?; - let oauth_data = oauth_data.map(|(code, app)| { - let scope = app.scopes.parse().unwrap(); - let redirect_uri = app.redirect_uri.parse().unwrap(); + oauth_data + .map(|(code, app)| { + let scope = app.scopes.parse().unwrap(); + let redirect_uri = app.redirect_uri.parse().unwrap(); - Grant { - owner_id: code.user_id.to_string(), - client_id: code.application_id.to_string(), - scope, - redirect_uri, - until: timestamp_to_chrono(code.expires_at), - extensions: Extensions::default(), - } - }); + Grant { + owner_id: code.user_id.to_string(), + client_id: code.application_id.to_string(), + scope, + redirect_uri, + until: timestamp_to_chrono(code.expires_at), + extensions: Extensions::default(), + } + }) + }; - Ok(oauth_data) + result.map_err(|error| debug!(?error, "extract failed")) } } diff --git a/kitsune/src/oauth2/issuer.rs b/kitsune/src/oauth2/issuer.rs index 68b9ea849..ba0f4ba95 100644 --- a/kitsune/src/oauth2/issuer.rs +++ b/kitsune/src/oauth2/issuer.rs @@ -3,12 +3,11 @@ use async_trait::async_trait; use diesel::{ExpressionMethods, OptionalExtension, QueryDsl, SelectableHelper}; use diesel_async::RunQueryDsl; use kitsune_db::{ - catch_error, model::oauth2, schema::{oauth2_access_tokens, oauth2_applications, oauth2_refresh_tokens}, with_connection, with_transaction, PgPool, }; -use kitsune_error::Error; +use kitsune_error::{Error, Result}; use kitsune_util::generate_secret; use oxide_auth::primitives::{ grant::{Extensions, Grant}, @@ -16,6 +15,7 @@ use oxide_auth::primitives::{ prelude::IssuedToken, }; use oxide_auth_async::primitives::Issuer; +use trials::attempt; #[derive(Clone)] pub struct OAuthIssuer { @@ -24,161 +24,167 @@ pub struct OAuthIssuer { #[async_trait] impl Issuer for OAuthIssuer { + #[instrument(skip_all)] async fn issue(&mut self, grant: Grant) -> Result { - let application_id = grant.client_id.parse().map_err(|_| ())?; - let user_id = grant.owner_id.parse().map_err(|_| ())?; - let scopes = grant.scope.to_string(); - let expires_at = chrono_to_timestamp(grant.until); - - let (access_token, refresh_token) = catch_error!(with_transaction!(self.db_pool, |tx| { - let access_token = diesel::insert_into(oauth2_access_tokens::table) - .values(oauth2::NewAccessToken { - token: generate_secret().as_str(), - user_id: Some(user_id), - application_id: Some(application_id), - scopes: scopes.as_str(), - expires_at, - }) - .returning(oauth2::AccessToken::as_returning()) - .get_result::(tx) - .await?; - - let refresh_token = diesel::insert_into(oauth2_refresh_tokens::table) - .values(oauth2::NewRefreshToken { - token: generate_secret().as_str(), - access_token: access_token.token.as_str(), - application_id, - }) - .returning(oauth2::RefreshToken::as_returning()) - .get_result::(tx) - .await?; - - Ok::<_, Error>((access_token, refresh_token)) - })) - .map_err(|_| ())? - .map_err(|_| ())?; - - Ok(IssuedToken { - token: access_token.token, - refresh: Some(refresh_token.token), - until: grant.until, - token_type: TokenType::Bearer, - }) + let result: Result<_> = attempt! { async + let application_id = grant.client_id.parse()?; + let user_id = grant.owner_id.parse()?; + let scopes = grant.scope.to_string(); + let expires_at = chrono_to_timestamp(grant.until); + + let (access_token, refresh_token) = with_transaction!(self.db_pool, |tx| { + let access_token = diesel::insert_into(oauth2_access_tokens::table) + .values(oauth2::NewAccessToken { + token: generate_secret().as_str(), + user_id: Some(user_id), + application_id: Some(application_id), + scopes: scopes.as_str(), + expires_at, + }) + .returning(oauth2::AccessToken::as_returning()) + .get_result::(tx) + .await?; + + let refresh_token = diesel::insert_into(oauth2_refresh_tokens::table) + .values(oauth2::NewRefreshToken { + token: generate_secret().as_str(), + access_token: access_token.token.as_str(), + application_id, + }) + .returning(oauth2::RefreshToken::as_returning()) + .get_result::(tx) + .await?; + + Ok::<_, Error>((access_token, refresh_token)) + })?; + + IssuedToken { + token: access_token.token, + refresh: Some(refresh_token.token), + until: grant.until, + token_type: TokenType::Bearer, + } + }; + + result.map_err(|error| debug!(?error, "failed to issue token")) } + #[instrument(skip_all)] async fn refresh(&mut self, refresh_token: &str, grant: Grant) -> Result { - let (refresh_token, access_token) = - catch_error!(with_connection!(self.db_pool, |db_conn| { - oauth2_refresh_tokens::table - .find(refresh_token) - .inner_join(oauth2_access_tokens::table) - .select(<(oauth2::RefreshToken, oauth2::AccessToken)>::as_select()) - .get_result::<(oauth2::RefreshToken, oauth2::AccessToken)>(db_conn) - .await - })) - .map_err(|_| ())? - .map_err(|_| ())?; - - let (access_token, refresh_token) = catch_error!(with_transaction!(self.db_pool, |tx| { - let new_access_token = diesel::insert_into(oauth2_access_tokens::table) - .values(oauth2::NewAccessToken { - user_id: access_token.user_id, - token: generate_secret().as_str(), - application_id: access_token.application_id, - scopes: access_token.scopes.as_str(), - expires_at: chrono_to_timestamp(grant.until), - }) - .get_result::(tx) - .await?; - - let refresh_token = diesel::update(&refresh_token) - .set(oauth2_refresh_tokens::access_token.eq(new_access_token.token.as_str())) - .get_result::(tx) - .await?; - - diesel::delete(&access_token).execute(tx).await?; - - Ok::<_, Error>((new_access_token, refresh_token)) - })) - .map_err(|_| ())? - .map_err(|_| ())?; - - Ok(RefreshedToken { - token: access_token.token, - refresh: Some(refresh_token.token), - until: timestamp_to_chrono(access_token.expires_at), - token_type: TokenType::Bearer, - }) + let result: Result<_> = attempt! { async + let (refresh_token, access_token) = + with_connection!(self.db_pool, |db_conn| { + oauth2_refresh_tokens::table + .find(refresh_token) + .inner_join(oauth2_access_tokens::table) + .select(<(oauth2::RefreshToken, oauth2::AccessToken)>::as_select()) + .get_result::<(oauth2::RefreshToken, oauth2::AccessToken)>(db_conn) + .await + })?; + + let (access_token, refresh_token) = with_transaction!(self.db_pool, |tx| { + let new_access_token = diesel::insert_into(oauth2_access_tokens::table) + .values(oauth2::NewAccessToken { + user_id: access_token.user_id, + token: generate_secret().as_str(), + application_id: access_token.application_id, + scopes: access_token.scopes.as_str(), + expires_at: chrono_to_timestamp(grant.until), + }) + .get_result::(tx) + .await?; + + let refresh_token = diesel::update(&refresh_token) + .set(oauth2_refresh_tokens::access_token.eq(new_access_token.token.as_str())) + .get_result::(tx) + .await?; + + diesel::delete(&access_token).execute(tx).await?; + + Ok::<_, Error>((new_access_token, refresh_token)) + })?; + + RefreshedToken { + token: access_token.token, + refresh: Some(refresh_token.token), + until: timestamp_to_chrono(access_token.expires_at), + token_type: TokenType::Bearer, + } + }; + + result.map_err(|error| debug!(?error, "failed to refresh token")) } + #[instrument(skip_all)] async fn recover_token(&mut self, access_token: &str) -> Result, ()> { - let oauth_data = catch_error!(with_connection!(self.db_pool, |db_conn| { - oauth2_access_tokens::table - .find(access_token) - .inner_join(oauth2_applications::table) - .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) - .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) - .await - .optional() - })) - .map_err(|_| ())? - .map_err(|_| ())?; - - let oauth_data = oauth_data.map(|(access_token, app)| { - let scope = app.scopes.parse().unwrap(); - let redirect_uri = app.redirect_uri.parse().unwrap(); - let until = timestamp_to_chrono(access_token.expires_at); - - Grant { - owner_id: access_token - .user_id - .as_ref() - .map(ToString::to_string) - .unwrap_or_default(), - client_id: app.id.to_string(), - scope, - redirect_uri, - until, - extensions: Extensions::default(), - } - }); - - Ok(oauth_data) + let result: Result<_> = attempt! { async + let oauth_data = with_connection!(self.db_pool, |db_conn| { + oauth2_access_tokens::table + .find(access_token) + .inner_join(oauth2_applications::table) + .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) + .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) + .await + .optional() + })?; + + oauth_data.map(|(access_token, app)| { + let scope = app.scopes.parse().unwrap(); + let redirect_uri = app.redirect_uri.parse().unwrap(); + let until = timestamp_to_chrono(access_token.expires_at); + + Grant { + owner_id: access_token + .user_id + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(), + client_id: app.id.to_string(), + scope, + redirect_uri, + until, + extensions: Extensions::default(), + } + }) + }; + + result.map_err(|error| debug!(?error, "failed to recover token grant")) } + #[instrument(skip_all)] async fn recover_refresh(&mut self, refresh_token: &str) -> Result, ()> { - let oauth_data = catch_error!(with_connection!(self.db_pool, |db_conn| { - oauth2_refresh_tokens::table - .find(refresh_token) - .inner_join(oauth2_access_tokens::table) - .inner_join(oauth2_applications::table) - .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) - .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) - .await - .optional() - })) - .map_err(|_| ())? - .map_err(|_| ())?; - - let oauth_data = oauth_data.map(|(access_token, app)| { - let scope = access_token.scopes.parse().unwrap(); - let redirect_uri = app.redirect_uri.parse().unwrap(); - let until = chrono::NaiveDateTime::MAX.and_utc(); - - Grant { - owner_id: access_token - .user_id - .as_ref() - .map(ToString::to_string) - .unwrap_or_default(), - client_id: app.id.to_string(), - scope, - redirect_uri, - until, - extensions: Extensions::default(), - } - }); - - Ok(oauth_data) + let result: Result<_> = attempt! { async + let oauth_data = with_connection!(self.db_pool, |db_conn| { + oauth2_refresh_tokens::table + .find(refresh_token) + .inner_join(oauth2_access_tokens::table) + .inner_join(oauth2_applications::table) + .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) + .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) + .await + .optional() + })?; + + oauth_data.map(|(access_token, app)| { + let scope = access_token.scopes.parse().unwrap(); + let redirect_uri = app.redirect_uri.parse().unwrap(); + let until = chrono::NaiveDateTime::MAX.and_utc(); + + Grant { + owner_id: access_token + .user_id + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(), + client_id: app.id.to_string(), + scope, + redirect_uri, + until, + extensions: Extensions::default(), + } + }) + }; + + result.map_err(|error| debug!(?error, "failed to recover refresh grant")) } } diff --git a/kitsune/src/oauth2/registrar.rs b/kitsune/src/oauth2/registrar.rs index d6f567974..57d6f1765 100644 --- a/kitsune/src/oauth2/registrar.rs +++ b/kitsune/src/oauth2/registrar.rs @@ -1,9 +1,8 @@ use async_trait::async_trait; use diesel::{ExpressionMethods, OptionalExtension, QueryDsl}; use diesel_async::RunQueryDsl; -use kitsune_db::{ - catch_error, model::oauth2, schema::oauth2_applications, with_connection, PgPool, -}; +use kitsune_db::{model::oauth2, schema::oauth2_applications, with_connection, PgPool}; +use kitsune_error::Result; use oxide_auth::{ endpoint::{PreGrant, Scope}, primitives::registrar::{BoundClient, ClientUrl, ExactUrl, RegisteredUrl, RegistrarError}, @@ -14,6 +13,7 @@ use std::{ borrow::Cow, str::{self, FromStr}, }; +use trials::attempt; use super::OAuthScope; @@ -24,6 +24,7 @@ pub struct OAuthRegistrar { #[async_trait] impl Registrar for OAuthRegistrar { + #[instrument(skip_all)] async fn bound_redirect<'a>( &self, bound: ClientUrl<'a>, @@ -38,6 +39,7 @@ impl Registrar for OAuthRegistrar { } } + #[instrument(skip_all)] async fn negotiate<'a>( &self, client: BoundClient<'a>, @@ -48,17 +50,20 @@ impl Registrar for OAuthRegistrar { .parse() .map_err(|_| RegistrarError::PrimitiveError)?; - let client = catch_error!(with_connection!(self.db_pool, |db_conn| { - oauth2_applications::table - .find(client_id) - .filter(oauth2_applications::redirect_uri.eq(client.redirect_uri.as_str())) - .get_result::(db_conn) - .await - .optional() - })) - .map_err(|_| RegistrarError::PrimitiveError)? - .map_err(|_| RegistrarError::PrimitiveError)? - .ok_or(RegistrarError::Unspecified)?; + let client_result: Result<_> = attempt! { async + with_connection!(self.db_pool, |db_conn| { + oauth2_applications::table + .find(client_id) + .filter(oauth2_applications::redirect_uri.eq(client.redirect_uri.as_str())) + .get_result::(db_conn) + .await + .optional() + })? + }; + + let client = client_result + .map_err(|_| RegistrarError::PrimitiveError)? + .ok_or(RegistrarError::Unspecified)?; let client_id = client.id.to_string(); let redirect_uri = ExactUrl::new(client.redirect_uri) @@ -91,6 +96,7 @@ impl Registrar for OAuthRegistrar { }) } + #[instrument(skip_all)] async fn check( &self, client_id: &str, @@ -107,16 +113,19 @@ impl Registrar for OAuthRegistrar { client_query = client_query.filter(oauth2_applications::secret.eq(passphrase)); } - catch_error!(with_connection!(self.db_pool, |db_conn| { - client_query - .select(oauth2_applications::id) - .execute(db_conn) - .await - .optional() - })) - .map_err(|_| RegistrarError::PrimitiveError)? - .map_err(|_| RegistrarError::PrimitiveError)? - .map(|_| ()) - .ok_or(RegistrarError::Unspecified) + let result: Result<_> = attempt! { async + with_connection!(self.db_pool, |db_conn| { + client_query + .select(oauth2_applications::id) + .execute(db_conn) + .await + .optional() + })? + }; + + result + .map_err(|_| RegistrarError::PrimitiveError)? + .map(|_| ()) + .ok_or(RegistrarError::Unspecified) } } diff --git a/kitsune/src/oauth2/solicitor.rs b/kitsune/src/oauth2/solicitor.rs index a2b0371ed..7468e8f0a 100644 --- a/kitsune/src/oauth2/solicitor.rs +++ b/kitsune/src/oauth2/solicitor.rs @@ -4,15 +4,15 @@ use async_trait::async_trait; use cursiv::CsrfHandle; use diesel::{OptionalExtension, QueryDsl}; use diesel_async::RunQueryDsl; -use kitsune_db::{ - catch_error, model::user::User, schema::oauth2_applications, with_connection, PgPool, -}; +use kitsune_db::{model::user::User, schema::oauth2_applications, with_connection, PgPool}; +use kitsune_error::Result; use oxide_auth::endpoint::{OAuthError, OwnerConsent, QueryParameter, Solicitation, WebRequest}; use oxide_auth_async::endpoint::OwnerSolicitor; use oxide_auth_axum::{OAuthRequest, OAuthResponse, WebError}; use speedy_uuid::Uuid; use std::{borrow::Cow, str::FromStr}; use strum::EnumMessage; +use trials::attempt; use typed_builder::TypedBuilder; #[derive(Template)] @@ -55,6 +55,7 @@ pub struct OAuthOwnerSolicitor { } impl OAuthOwnerSolicitor { + #[instrument(skip_all)] async fn check_consent( &self, login_consent: Option<&str>, @@ -81,17 +82,20 @@ impl OAuthOwnerSolicitor { .parse() .map_err(|_| WebError::Endpoint(OAuthError::BadRequest))?; - let app_name = catch_error!(with_connection!(self.db_pool, |db_conn| { - oauth2_applications::table - .find(client_id) - .select(oauth2_applications::name) - .get_result::(db_conn) - .await - .optional() - })) - .map_err(|_| WebError::InternalError(None))? - .map_err(|_| WebError::InternalError(None))? - .ok_or(WebError::Endpoint(OAuthError::DenySilently))?; + let app_name_result: Result> = attempt! { async + with_connection!(self.db_pool, |db_conn| { + oauth2_applications::table + .find(client_id) + .select(oauth2_applications::name) + .get_result::(db_conn) + .await + .optional() + })? + }; + + let app_name = app_name_result + .map_err(|_| WebError::InternalError(None))? + .ok_or(WebError::Endpoint(OAuthError::DenySilently))?; let scopes = solicitation .pre_grant() @@ -129,6 +133,7 @@ impl OAuthOwnerSolicitor { #[async_trait] impl OwnerSolicitor for OAuthOwnerSolicitor { + #[instrument(skip_all)] async fn check_consent( &mut self, req: &mut OAuthRequest,