Skip to content

Commit

Permalink
use try-block polyfills everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Apr 7, 2024
1 parent ed555ea commit 51b6f91
Show file tree
Hide file tree
Showing 15 changed files with 292 additions and 282 deletions.
5 changes: 4 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/kitsune-db/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ futures-util = { version = "0.3.30", default-features = false, features = [
] }
iso8601-timestamp = { version = "0.2.17", features = ["diesel-pg"] }
kitsune-config = { path = "../kitsune-config" }
kitsune-error = { path = "../kitsune-error" }
kitsune-language = { path = "../kitsune-language" }
kitsune-type = { path = "../kitsune-type" }
num-derive = "0.4.2"
Expand All @@ -36,12 +37,12 @@ rustls-native-certs = "0.7.0"
serde = { version = "1.0.197", features = ["derive"] }
simd-json = "0.13.9"
speedy-uuid = { path = "../../lib/speedy-uuid", features = ["diesel"] }
thiserror = "1.0.58"
tokio = { version = "1.37.0", features = ["rt"] }
tokio-postgres = "0.7.10"
tokio-postgres-rustls = "0.12.0"
tracing = "0.1.40"
tracing-log = "0.2.0"
trials = { path = "../../lib/trials" }
typed-builder = "0.18.1"

[dev-dependencies]
Expand Down
23 changes: 0 additions & 23 deletions crates/kitsune-db/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
use core::fmt;
use diesel_async::pooled_connection::bb8;
use std::error::Error as StdError;
use thiserror::Error;

pub type BoxError = Box<dyn StdError + Send + Sync>;
pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug)]
pub struct EnumConversionError(pub i32);
Expand Down Expand Up @@ -35,21 +30,3 @@ impl fmt::Display for IsoCodeConversionError {
}

impl StdError for IsoCodeConversionError {}

#[derive(Debug, Error)]
pub enum Error {
#[error(transparent)]
Blocking(#[from] blowocking::Error),

#[error(transparent)]
Diesel(#[from] diesel::result::Error),

#[error(transparent)]
DieselConnection(#[from] diesel::result::ConnectionError),

#[error(transparent)]
Migration(BoxError),

#[error(transparent)]
Pool(#[from] bb8::RunError),
}
6 changes: 3 additions & 3 deletions crates/kitsune-db/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ use diesel_async::{
};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use kitsune_config::database::Configuration as DatabaseConfig;
use kitsune_error::{Error, Result};
use tracing_log::LogTracer;

pub type PgPool = Pool<AsyncPgConnection>;

pub use crate::error::{Error, Result};
#[doc(hidden)]
pub use diesel_async;
pub use {diesel_async, kitsune_error, trials};

mod error;
mod pool;
Expand Down Expand Up @@ -45,7 +45,7 @@ pub async fn connect(config: &DatabaseConfig) -> Result<PgPool> {

migration_conn
.run_pending_migrations(MIGRATIONS)
.map_err(Error::Migration)?;
.map_err(Error::msg)?;

Ok::<_, Error>(())
}
Expand Down
15 changes: 4 additions & 11 deletions crates/kitsune-db/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@ macro_rules! with_connection {
}};
}

#[macro_export]
macro_rules! catch_error {
($($tt:tt)*) => {{
let result: ::std::result::Result<_, ::diesel_async::pooled_connection::bb8::RunError> = async {
Ok({ $($tt)* })
}.await;
result
}};
}

#[macro_export]
macro_rules! with_connection_panicky {
($pool:expr, $($other:tt)*) => {{
$crate::catch_error!($crate::with_connection!($pool, $($other)*)).unwrap()
let result: $crate::kitsune_error::Result<_> = $crate::trials::attempt! { async
$crate::with_connection!($pool, $($other)*)
};
result.unwrap()
}};
}

Expand Down
1 change: 1 addition & 0 deletions kitsune-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dotenvy = "0.15.7"
envy = "0.4.2"
kitsune-config = { path = "../crates/kitsune-config" }
kitsune-db = { path = "../crates/kitsune-db" }
kitsune-error = { path = "../crates/kitsune-error" }
serde = { version = "1.0.197", features = ["derive"] }
speedy-uuid = { path = "../lib/speedy-uuid" }
tokio = { version = "1.37.0", features = ["full"] }
Expand Down
3 changes: 2 additions & 1 deletion kitsune-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ async fn main() -> Result<()> {
max_connections: 1,
use_tls: config.database_use_tls,
})
.await?;
.await
.map_err(kitsune_error::Error::into_error)?;

let cmd = App::parse();

Expand Down
5 changes: 4 additions & 1 deletion kitsune-job-runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ async fn main() -> eyre::Result<()> {

kitsune_observability::initialise(env!("CARGO_PKG_NAME"), &config)?;

let db_pool = kitsune_db::connect(&config.database).await?;
let db_pool = kitsune_db::connect(&config.database)
.await
.map_err(kitsune_error::Error::into_error)?;

let job_queue =
kitsune_job_runner::prepare_job_queue(db_pool.clone(), &config.job_queue).await?;

Expand Down
1 change: 1 addition & 0 deletions kitsune/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ tower-http = { version = "0.5.2", features = [
] }
tower-http-digest = { path = "../lib/tower-http-digest" }
tracing = "0.1.40"
trials = { path = "../lib/trials" }
typed-builder = "0.18.1"
url = "2.5.0"
utoipa = { version = "4.2.0", features = ["axum_extras", "uuid"] }
Expand Down
23 changes: 13 additions & 10 deletions kitsune/src/http/extractor/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use diesel_async::RunQueryDsl;
use headers::{authorization::Bearer, Authorization};
use http::request::Parts;
use kitsune_db::{
catch_error,
model::{account::Account, user::User},
schema::{accounts, oauth2_access_tokens, users},
with_connection,
};
use kitsune_error::Error;
use kitsune_error::{Error, Result};
use time::OffsetDateTime;
use trials::attempt;

/// Mastodon-specific auth extractor alias
///
Expand Down Expand Up @@ -64,14 +64,17 @@ impl<const ENFORCE_EXPIRATION: bool> FromRequestParts<Zustand>
.filter(oauth2_access_tokens::expires_at.gt(OffsetDateTime::now_utc()));
}

let (user, account) = catch_error!(with_connection!(state.db_pool, |db_conn| {
user_account_query
.select(<(User, Account)>::as_select())
.get_result(db_conn)
.await
.map_err(Error::from)
}))
.map_err(Error::from)??;
let result: Result<(User, Account)> = attempt! { async
with_connection!(state.db_pool, |db_conn| {
user_account_query
.select(<(User, Account)>::as_select())
.get_result(db_conn)
.await
.map_err(Error::from)
})?
};

let (user, account) = result?;

Ok(Self(UserData { account, user }))
}
Expand Down
1 change: 1 addition & 0 deletions kitsune/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ async fn boot() -> eyre::Result<()> {

let conn = kitsune_db::connect(&config.database)
.await
.map_err(kitsune_error::Error::into_error)
.wrap_err("Failed to connect to and migrate the database")?;

let job_queue = kitsune_job_runner::prepare_job_queue(conn.clone(), &config.job_queue)
Expand Down
94 changes: 50 additions & 44 deletions kitsune/src/oauth2/authorizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use async_trait::async_trait;
use diesel::{OptionalExtension, QueryDsl};
use diesel_async::RunQueryDsl;
use kitsune_db::{
catch_error,
model::oauth2,
schema::{oauth2_applications, oauth2_authorization_codes},
with_connection, PgPool,
};
use kitsune_error::Result;
use kitsune_util::generate_secret;
use oxide_auth::primitives::grant::{Extensions, Grant};
use oxide_auth_async::primitives::Authorizer;
use trials::attempt;

#[derive(Clone)]
pub struct OAuthAuthorizer {
Expand All @@ -19,56 +20,61 @@ pub struct OAuthAuthorizer {

#[async_trait]
impl Authorizer for OAuthAuthorizer {
#[instrument(skip_all)]
async fn authorize(&mut self, grant: Grant) -> Result<String, ()> {
let application_id = grant.client_id.parse().map_err(|_| ())?;
let user_id = grant.owner_id.parse().map_err(|_| ())?;
let scopes = grant.scope.to_string();
let secret = generate_secret();
let expires_at = chrono_to_timestamp(grant.until);
let result: Result<_> = attempt! { async
let application_id = grant.client_id.parse()?;
let user_id = grant.owner_id.parse()?;
let scopes = grant.scope.to_string();
let secret = generate_secret();
let expires_at = chrono_to_timestamp(grant.until);

catch_error!(with_connection!(self.db_pool, |db_conn| {
diesel::insert_into(oauth2_authorization_codes::table)
.values(oauth2::NewAuthorizationCode {
code: secret.as_str(),
application_id,
user_id,
scopes: scopes.as_str(),
expires_at,
})
.returning(oauth2_authorization_codes::code)
.get_result(db_conn)
.await
}))
.map_err(|_| ())?
.map_err(|_| ())
with_connection!(self.db_pool, |db_conn| {
diesel::insert_into(oauth2_authorization_codes::table)
.values(oauth2::NewAuthorizationCode {
code: secret.as_str(),
application_id,
user_id,
scopes: scopes.as_str(),
expires_at,
})
.returning(oauth2_authorization_codes::code)
.get_result(db_conn)
.await
})?
};

result.map_err(|error| debug!(?error, "authorize failed"))
}

#[instrument(skip_all)]
async fn extract(&mut self, authorization_code: &str) -> Result<Option<Grant>, ()> {
let oauth_data = catch_error!(with_connection!(self.db_pool, |db_conn| {
oauth2_authorization_codes::table
.find(authorization_code)
.inner_join(oauth2_applications::table)
.first::<(oauth2::AuthorizationCode, oauth2::Application)>(db_conn)
.await
.optional()
}))
.map_err(|_| ())?
.map_err(|_| ())?;
let result: Result<_> = attempt! { async
let oauth_data = with_connection!(self.db_pool, |db_conn| {
oauth2_authorization_codes::table
.find(authorization_code)
.inner_join(oauth2_applications::table)
.first::<(oauth2::AuthorizationCode, oauth2::Application)>(db_conn)
.await
.optional()
})?;

let oauth_data = oauth_data.map(|(code, app)| {
let scope = app.scopes.parse().unwrap();
let redirect_uri = app.redirect_uri.parse().unwrap();
oauth_data
.map(|(code, app)| {
let scope = app.scopes.parse().unwrap();
let redirect_uri = app.redirect_uri.parse().unwrap();

Grant {
owner_id: code.user_id.to_string(),
client_id: code.application_id.to_string(),
scope,
redirect_uri,
until: timestamp_to_chrono(code.expires_at),
extensions: Extensions::default(),
}
});
Grant {
owner_id: code.user_id.to_string(),
client_id: code.application_id.to_string(),
scope,
redirect_uri,
until: timestamp_to_chrono(code.expires_at),
extensions: Extensions::default(),
}
})
};

Ok(oauth_data)
result.map_err(|error| debug!(?error, "extract failed"))
}
}
Loading

0 comments on commit 51b6f91

Please sign in to comment.