diff --git a/Cargo.lock b/Cargo.lock index 66a577b85..e54d8bbae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3645,6 +3645,7 @@ dependencies = [ "camino", "chrono", "clap", + "color-eyre", "cursiv", "der", "diesel", @@ -3729,6 +3730,7 @@ dependencies = [ "base64-simd", "diesel", "diesel-async", + "eyre", "futures-util", "headers", "http 1.1.0", @@ -3972,6 +3974,7 @@ version = "0.0.1-pre.6" dependencies = [ "athena", "clap", + "color-eyre", "just-retry", "kitsune-config", "kitsune-core", @@ -4005,7 +4008,6 @@ dependencies = [ "kitsune-core", "kitsune-db", "kitsune-email", - "scoped-futures", "serde", "speedy-uuid", "tracing", @@ -4043,7 +4045,6 @@ dependencies = [ "kitsune-url", "kitsune-util", "mime", - "scoped-futures", "serde", "simd-json", "smol_str", @@ -4216,7 +4217,6 @@ dependencies = [ "redis", "rsa", "rusty-s3", - "scoped-futures", "serde", "simd-json", "smol_str", diff --git a/crates/kitsune-activitypub/Cargo.toml b/crates/kitsune-activitypub/Cargo.toml index 25f7d908b..8dd1c6c02 100644 --- a/crates/kitsune-activitypub/Cargo.toml +++ b/crates/kitsune-activitypub/Cargo.toml @@ -11,6 +11,7 @@ autometrics = { version = "1.0.1", default-features = false } base64-simd = "0.8.0" diesel = "2.1.5" diesel-async = "0.4.1" +eyre = "0.6.12" futures-util = "0.3.30" headers = "0.4.0" http = "1.1.0" diff --git a/crates/kitsune-activitypub/src/deliverer/mod.rs b/crates/kitsune-activitypub/src/deliverer/mod.rs index 770df0a2b..5e7266ef8 100644 --- a/crates/kitsune-activitypub/src/deliverer/mod.rs +++ b/crates/kitsune-activitypub/src/deliverer/mod.rs @@ -11,20 +11,16 @@ use diesel::{ use diesel_async::RunQueryDsl; use futures_util::TryStreamExt; use iso8601_timestamp::Timestamp; -use kitsune_core::{ - error::BoxError, - traits::{deliverer::Action, Deliverer as DelivererTrait}, -}; +use kitsune_core::traits::{deliverer::Action, Deliverer as DelivererTrait}; use kitsune_db::{ model::{account::Account, favourite::Favourite, follower::Follow, post::Post, user::User}, schema::{accounts, posts, users}, - PgPool, + with_connection, PgPool, }; use kitsune_service::attachment::AttachmentService; use kitsune_type::ap::{ap_context, Activity, ActivityType, ObjectField}; use kitsune_url::UrlService; use kitsune_util::try_join; -use scoped_futures::ScopedFutureExt; use std::sync::Arc; use typed_builder::TypedBuilder; @@ -61,26 +57,21 @@ impl Deliverer { } async fn accept_follow(&self, follow: Follow) -> Result<()> { - let (follower_inbox_url, (followed_account, followed_user)): (String, _) = self - .db_pool - .with_connection(|db_conn| { - async move { - let follower_inbox_url_fut = accounts::table - .find(follow.follower_id) - .select(accounts::inbox_url.assume_not_null()) - .get_result::(db_conn); - - let followed_info_fut = accounts::table - .find(follow.account_id) - .inner_join(users::table.on(accounts::id.eq(users::account_id))) - .select(<(Account, User)>::as_select()) - .get_result::<(Account, User)>(db_conn); - - try_join!(follower_inbox_url_fut, followed_info_fut) - } - .scoped() - }) - .await?; + let (follower_inbox_url, (followed_account, followed_user)): (String, _) = + with_connection!(self.db_pool, |db_conn| { + let follower_inbox_url_fut = accounts::table + .find(follow.follower_id) + .select(accounts::inbox_url.assume_not_null()) + .get_result::(db_conn); + + let followed_info_fut = accounts::table + .find(follow.account_id) + .inner_join(users::table.on(accounts::id.eq(users::account_id))) + .select(<(Account, User)>::as_select()) + .get_result::<(Account, User)>(db_conn); + + try_join!(follower_inbox_url_fut, followed_info_fut) + })?; let followed_account_url = self.service.url.user_url(followed_account.id); @@ -111,17 +102,14 @@ impl Deliverer { } async fn create_or_repost(&self, post: Post) -> Result<()> { - let (account, user) = self - .db_pool - .with_connection(|db_conn| { - accounts::table - .find(post.account_id) - .inner_join(users::table) - .select(<(Account, User)>::as_select()) - .get_result::<(Account, User)>(db_conn) - .scoped() - }) - .await?; + let (account, user) = with_connection!(self.db_pool, |db_conn| { + accounts::table + .find(post.account_id) + .inner_join(users::table) + .select(<(Account, User)>::as_select()) + .get_result::<(Account, User)>(db_conn) + .await + })?; let inbox_stream = self .inbox_resolver @@ -141,21 +129,15 @@ impl Deliverer { } async fn delete_or_unrepost(&self, post: Post) -> Result<()> { - let account_user_data = self - .db_pool - .with_connection(|db_conn| { - async move { - accounts::table - .find(post.account_id) - .inner_join(users::table) - .select(<(Account, User)>::as_select()) - .get_result::<(Account, User)>(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let account_user_data = with_connection!(self.db_pool, |db_conn| { + accounts::table + .find(post.account_id) + .inner_join(users::table) + .select(<(Account, User)>::as_select()) + .get_result::<(Account, User)>(db_conn) + .await + .optional() + })?; let Some((account, user)) = account_user_data else { return Ok(()); @@ -179,27 +161,21 @@ impl Deliverer { } async fn favourite(&self, favourite: Favourite) -> Result<()> { - let ((account, user), inbox_url) = self - .db_pool - .with_connection(|db_conn| { - async move { - let account_user_fut = accounts::table - .find(favourite.account_id) - .inner_join(users::table) - .select(<(Account, User)>::as_select()) - .get_result(db_conn); - - let inbox_url_fut = posts::table - .find(favourite.post_id) - .inner_join(accounts::table) - .select(accounts::inbox_url) - .get_result::>(db_conn); - - try_join!(account_user_fut, inbox_url_fut) - } - .scoped() - }) - .await?; + let ((account, user), inbox_url) = with_connection!(self.db_pool, |db_conn| { + let account_user_fut = accounts::table + .find(favourite.account_id) + .inner_join(users::table) + .select(<(Account, User)>::as_select()) + .get_result(db_conn); + + let inbox_url_fut = posts::table + .find(favourite.post_id) + .inner_join(accounts::table) + .select(accounts::inbox_url) + .get_result::>(db_conn); + + try_join!(account_user_fut, inbox_url_fut) + })?; if let Some(ref inbox_url) = inbox_url { let activity = favourite.into_activity(self.mapping_state()).await?; @@ -213,26 +189,21 @@ impl Deliverer { } async fn follow(&self, follow: Follow) -> Result<()> { - let ((follower, follower_user), followed_inbox) = self - .db_pool - .with_connection(|db_conn| { - async move { - let follower_info_fut = accounts::table - .find(follow.follower_id) - .inner_join(users::table) - .select(<(Account, User)>::as_select()) - .get_result::<(Account, User)>(db_conn); - - let followed_inbox_fut = accounts::table - .find(follow.account_id) - .select(accounts::inbox_url) - .get_result::>(db_conn); - - try_join!(follower_info_fut, followed_inbox_fut) - } - .scoped() - }) - .await?; + let ((follower, follower_user), followed_inbox) = + with_connection!(self.db_pool, |db_conn| { + let follower_info_fut = accounts::table + .find(follow.follower_id) + .inner_join(users::table) + .select(<(Account, User)>::as_select()) + .get_result::<(Account, User)>(db_conn); + + let followed_inbox_fut = accounts::table + .find(follow.account_id) + .select(accounts::inbox_url) + .get_result::>(db_conn); + + try_join!(follower_info_fut, followed_inbox_fut) + })?; if let Some(followed_inbox) = followed_inbox { let follow_activity = follow.into_activity(self.mapping_state()).await?; @@ -246,28 +217,23 @@ impl Deliverer { } async fn reject_follow(&self, follow: Follow) -> Result<()> { - let (follower_inbox_url, (followed_account, followed_user), _delete_result) = self - .db_pool - .with_connection(|db_conn| { - async { - let follower_inbox_url_fut = accounts::table - .find(follow.follower_id) - .select(accounts::inbox_url.assume_not_null()) - .get_result::(db_conn); - - let followed_info_fut = accounts::table - .find(follow.account_id) - .inner_join(users::table.on(accounts::id.eq(users::account_id))) - .select(<(Account, User)>::as_select()) - .get_result::<(Account, User)>(db_conn); - - let delete_fut = diesel::delete(&follow).execute(db_conn); - - try_join!(follower_inbox_url_fut, followed_info_fut, delete_fut) - } - .scoped() - }) - .await?; + let (follower_inbox_url, (followed_account, followed_user), _delete_result) = + with_connection!(self.db_pool, |db_conn| { + let follower_inbox_url_fut = accounts::table + .find(follow.follower_id) + .select(accounts::inbox_url.assume_not_null()) + .get_result::(db_conn); + + let followed_info_fut = accounts::table + .find(follow.account_id) + .inner_join(users::table.on(accounts::id.eq(users::account_id))) + .select(<(Account, User)>::as_select()) + .get_result::<(Account, User)>(db_conn); + + let delete_fut = diesel::delete(&follow).execute(db_conn); + + try_join!(follower_inbox_url_fut, followed_info_fut, delete_fut) + })?; let followed_account_url = self.service.url.user_url(followed_account.id); @@ -298,27 +264,21 @@ impl Deliverer { } async fn unfavourite(&self, favourite: Favourite) -> Result<()> { - let ((account, user), inbox_url) = self - .db_pool - .with_connection(|db_conn| { - async move { - let account_user_fut = accounts::table - .find(favourite.account_id) - .inner_join(users::table) - .select(<(Account, User)>::as_select()) - .get_result(db_conn); - - let inbox_url_fut = posts::table - .find(favourite.post_id) - .inner_join(accounts::table) - .select(accounts::inbox_url) - .get_result::>(db_conn); - - try_join!(account_user_fut, inbox_url_fut) - } - .scoped() - }) - .await?; + let ((account, user), inbox_url) = with_connection!(self.db_pool, |db_conn| { + let account_user_fut = accounts::table + .find(favourite.account_id) + .inner_join(users::table) + .select(<(Account, User)>::as_select()) + .get_result(db_conn); + + let inbox_url_fut = posts::table + .find(favourite.post_id) + .inner_join(accounts::table) + .select(accounts::inbox_url) + .get_result::>(db_conn); + + try_join!(account_user_fut, inbox_url_fut) + })?; if let Some(ref inbox_url) = inbox_url { let activity = favourite.into_negate_activity(self.mapping_state()).await?; @@ -331,26 +291,21 @@ impl Deliverer { } async fn unfollow(&self, follow: Follow) -> Result<()> { - let ((follower, follower_user), followed_account_inbox_url) = self - .db_pool - .with_connection(|db_conn| { - async { - let follower_info_fut = accounts::table - .find(follow.follower_id) - .inner_join(users::table) - .select(<(Account, User)>::as_select()) - .get_result::<(Account, User)>(db_conn); - - let followed_account_inbox_url_fut = accounts::table - .find(follow.account_id) - .select(accounts::inbox_url) - .get_result::>(db_conn); - - try_join!(follower_info_fut, followed_account_inbox_url_fut) - } - .scoped() - }) - .await?; + let ((follower, follower_user), followed_account_inbox_url) = + with_connection!(self.db_pool, |db_conn| { + let follower_info_fut = accounts::table + .find(follow.follower_id) + .inner_join(users::table) + .select(<(Account, User)>::as_select()) + .get_result::<(Account, User)>(db_conn); + + let followed_account_inbox_url_fut = accounts::table + .find(follow.account_id) + .select(accounts::inbox_url) + .get_result::>(db_conn); + + try_join!(follower_info_fut, followed_account_inbox_url_fut) + })?; if let Some(ref followed_account_inbox_url) = followed_account_inbox_url { let follow_activity = follow.into_negate_activity(self.mapping_state()).await?; @@ -369,20 +324,14 @@ impl Deliverer { } async fn update_account(&self, account: Account) -> Result<()> { - let user = self - .db_pool - .with_connection(|db_conn| { - async move { - users::table - .filter(users::account_id.eq(account.id)) - .select(User::as_select()) - .get_result(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let user = with_connection!(self.db_pool, |db_conn| { + users::table + .filter(users::account_id.eq(account.id)) + .select(User::as_select()) + .get_result(db_conn) + .await + .optional() + })?; let Some(user) = user else { return Ok(()); @@ -404,22 +353,16 @@ impl Deliverer { } async fn update_post(&self, post: Post) -> Result<()> { - let post_account_user_data = self - .db_pool - .with_connection(|db_conn| { - async move { - posts::table - .find(post.id) - .inner_join(accounts::table) - .inner_join(users::table.on(accounts::id.eq(users::account_id))) - .select(<(Account, User)>::as_select()) - .get_result(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let post_account_user_data = with_connection!(self.db_pool, |db_conn| { + posts::table + .find(post.id) + .inner_join(accounts::table) + .inner_join(users::table.on(accounts::id.eq(users::account_id))) + .select(<(Account, User)>::as_select()) + .get_result(db_conn) + .await + .optional() + })?; let Some((account, user)) = post_account_user_data else { return Ok(()); @@ -447,7 +390,7 @@ impl Deliverer { #[async_trait] impl DelivererTrait for Deliverer { - async fn deliver(&self, action: Action) -> Result<(), BoxError> { + async fn deliver(&self, action: Action) -> eyre::Result<()> { match action { Action::AcceptFollow(follow) => self.accept_follow(follow).await, Action::Create(post) | Action::Repost(post) => self.create_or_repost(post).await, diff --git a/crates/kitsune-activitypub/src/error.rs b/crates/kitsune-activitypub/src/error.rs index c03f70c2c..51b5cfb87 100644 --- a/crates/kitsune-activitypub/src/error.rs +++ b/crates/kitsune-activitypub/src/error.rs @@ -1,10 +1,6 @@ use diesel_async::pooled_connection::bb8; -use kitsune_core::error::BoxError; use rsa::pkcs8::der; -use std::{ - convert::Infallible, - fmt::{Debug, Display}, -}; +use std::{convert::Infallible, fmt::Debug}; use thiserror::Error; pub type Result = std::result::Result; @@ -33,13 +29,13 @@ pub enum Error { FederationFilter(#[from] kitsune_federation_filter::error::Error), #[error(transparent)] - FetchAccount(BoxError), + FetchAccount(eyre::Report), #[error(transparent)] - FetchEmoji(BoxError), + FetchEmoji(eyre::Report), #[error(transparent)] - FetchPost(BoxError), + FetchPost(eyre::Report), #[error(transparent)] Http(#[from] http::Error), @@ -66,7 +62,7 @@ pub enum Error { NotFound, #[error(transparent)] - Resolver(BoxError), + Resolver(eyre::Report), #[error(transparent)] Search(#[from] kitsune_search::Error), @@ -89,15 +85,3 @@ impl From for Error { match err {} } } - -impl From> for Error -where - E: Into + Debug + Display, -{ - fn from(value: kitsune_db::PoolError) -> Self { - match value { - kitsune_db::PoolError::Pool(err) => err.into(), - kitsune_db::PoolError::User(err) => err.into(), - } - } -} diff --git a/crates/kitsune-activitypub/src/fetcher/actor.rs b/crates/kitsune-activitypub/src/fetcher/actor.rs index cd926ce01..472e6b318 100644 --- a/crates/kitsune-activitypub/src/fetcher/actor.rs +++ b/crates/kitsune-activitypub/src/fetcher/actor.rs @@ -11,11 +11,11 @@ use kitsune_core::traits::fetcher::AccountFetchOptions; use kitsune_db::{ model::account::{Account, AccountConflictChangeset, NewAccount, UpdateAccountMedia}, schema::accounts, + with_connection, with_transaction, }; use kitsune_search::SearchBackend; use kitsune_type::ap::actor::Actor; use kitsune_util::{convert::timestamp_to_uuid, sanitize::CleanHtmlExt}; -use scoped_futures::ScopedFutureExt; use url::Url; impl Fetcher { @@ -36,20 +36,14 @@ impl Fetcher { return Ok(Some(user)); } - let user_data = self - .db_pool - .with_connection(|db_conn| { - async move { - accounts::table - .filter(accounts::url.eq(opts.url)) - .select(Account::as_select()) - .first(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let user_data = with_connection!(self.db_pool, |db_conn| { + accounts::table + .filter(accounts::url.eq(opts.url)) + .select(Account::as_select()) + .first(db_conn) + .await + .optional() + })?; if let Some(user) = user_data { return Ok(Some(user)); @@ -96,93 +90,87 @@ impl Fetcher { actor.clean_html(); - let account: Account = self - .db_pool - .with_transaction(|tx| { - async move { - let account = diesel::insert_into(accounts::table) - .values(NewAccount { - id: timestamp_to_uuid(actor.published), - display_name: actor.name.as_deref(), - note: actor.subject.as_deref(), - username: actor.preferred_username.as_str(), - locked: actor.manually_approves_followers, - local: false, - domain, - actor_type: actor.r#type.into(), - url: actor.id.as_str(), - featured_collection_url: actor.featured.as_deref(), - followers_url: actor.followers.as_deref(), - following_url: actor.following.as_deref(), - inbox_url: Some(actor.inbox.as_str()), - outbox_url: actor.outbox.as_deref(), - shared_inbox_url: actor - .endpoints - .and_then(|endpoints| endpoints.shared_inbox) - .as_deref(), - public_key_id: actor.public_key.id.as_str(), - public_key: actor.public_key.public_key_pem.as_str(), - created_at: Some(actor.published), - }) - .on_conflict(accounts::url) - .do_update() - .set(AccountConflictChangeset { - display_name: actor.name.as_deref(), - note: actor.subject.as_deref(), - locked: actor.manually_approves_followers, - public_key_id: actor.public_key.id.as_str(), - public_key: actor.public_key.public_key_pem.as_str(), - }) + let account: Account = with_transaction!(self.db_pool, |tx| { + let account = diesel::insert_into(accounts::table) + .values(NewAccount { + id: timestamp_to_uuid(actor.published), + display_name: actor.name.as_deref(), + note: actor.subject.as_deref(), + username: actor.preferred_username.as_str(), + locked: actor.manually_approves_followers, + local: false, + domain, + actor_type: actor.r#type.into(), + url: actor.id.as_str(), + featured_collection_url: actor.featured.as_deref(), + followers_url: actor.followers.as_deref(), + following_url: actor.following.as_deref(), + inbox_url: Some(actor.inbox.as_str()), + outbox_url: actor.outbox.as_deref(), + shared_inbox_url: actor + .endpoints + .and_then(|endpoints| endpoints.shared_inbox) + .as_deref(), + public_key_id: actor.public_key.id.as_str(), + public_key: actor.public_key.public_key_pem.as_str(), + created_at: Some(actor.published), + }) + .on_conflict(accounts::url) + .do_update() + .set(AccountConflictChangeset { + display_name: actor.name.as_deref(), + note: actor.subject.as_deref(), + locked: actor.manually_approves_followers, + public_key_id: actor.public_key.id.as_str(), + public_key: actor.public_key.public_key_pem.as_str(), + }) + .returning(Account::as_returning()) + .get_result::(tx) + .await?; + + let avatar_id = if let Some(icon) = actor.icon { + process_attachments(tx, &account, &[icon]).await?.pop() + } else { + None + }; + + let header_id = if let Some(image) = actor.image { + process_attachments(tx, &account, &[image]).await?.pop() + } else { + None + }; + + let mut update_changeset = UpdateAccountMedia::default(); + if let Some(avatar_id) = avatar_id { + update_changeset = UpdateAccountMedia { + avatar_id: Some(avatar_id), + ..update_changeset + }; + } + + if let Some(header_id) = header_id { + update_changeset = UpdateAccountMedia { + header_id: Some(header_id), + ..update_changeset + }; + } + + let account = match update_changeset { + UpdateAccountMedia { + avatar_id: None, + header_id: None, + } => account, + _ => { + diesel::update(&account) + .set(update_changeset) .returning(Account::as_returning()) - .get_result::(tx) - .await?; - - let avatar_id = if let Some(icon) = actor.icon { - process_attachments(tx, &account, &[icon]).await?.pop() - } else { - None - }; - - let header_id = if let Some(image) = actor.image { - process_attachments(tx, &account, &[image]).await?.pop() - } else { - None - }; - - let mut update_changeset = UpdateAccountMedia::default(); - if let Some(avatar_id) = avatar_id { - update_changeset = UpdateAccountMedia { - avatar_id: Some(avatar_id), - ..update_changeset - }; - } - - if let Some(header_id) = header_id { - update_changeset = UpdateAccountMedia { - header_id: Some(header_id), - ..update_changeset - }; - } - - let account = match update_changeset { - UpdateAccountMedia { - avatar_id: None, - header_id: None, - } => account, - _ => { - diesel::update(&account) - .set(update_changeset) - .returning(Account::as_returning()) - .get_result(tx) - .await? - } - }; - - Ok::<_, Error>(account) + .get_result(tx) + .await? } - .scoped() - }) - .await?; + }; + + Ok::<_, Error>(account) + })?; self.search_backend .add_to_index(account.clone().into()) diff --git a/crates/kitsune-activitypub/src/fetcher/emoji.rs b/crates/kitsune-activitypub/src/fetcher/emoji.rs index 37ef484f7..ba340b299 100644 --- a/crates/kitsune-activitypub/src/fetcher/emoji.rs +++ b/crates/kitsune-activitypub/src/fetcher/emoji.rs @@ -9,28 +9,22 @@ use kitsune_db::{ media_attachment::{MediaAttachment, NewMediaAttachment}, }, schema::{custom_emojis, media_attachments}, + with_connection, with_transaction, }; use kitsune_type::ap::emoji::Emoji; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use url::Url; impl Fetcher { pub(crate) async fn fetch_emoji(&self, url: &str) -> Result> { - let existing_emoji = self - .db_pool - .with_connection(|db_conn| { - async move { - custom_emojis::table - .filter(custom_emojis::remote_id.eq(url)) - .select(CustomEmoji::as_select()) - .first(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let existing_emoji = with_connection!(self.db_pool, |db_conn| { + custom_emojis::table + .filter(custom_emojis::remote_id.eq(url)) + .select(CustomEmoji::as_select()) + .first(db_conn) + .await + .optional() + })?; if let Some(emoji) = existing_emoji { return Ok(Some(emoji)); @@ -57,42 +51,36 @@ impl Fetcher { let name_pure = emoji.name.replace(':', ""); - let emoji: CustomEmoji = self - .db_pool - .with_transaction(|tx| { - async move { - let media_attachment = diesel::insert_into(media_attachments::table) - .values(NewMediaAttachment { - id: Uuid::now_v7(), - account_id: None, - content_type, - description: None, - blurhash: None, - file_path: None, - remote_url: Some(&emoji.icon.url), - }) - .returning(MediaAttachment::as_returning()) - .get_result::(tx) - .await?; - let emoji = diesel::insert_into(custom_emojis::table) - .values(CustomEmoji { - id: Uuid::now_v7(), - remote_id: emoji.id, - shortcode: name_pure.to_string(), - domain: Some(domain.to_string()), - media_attachment_id: media_attachment.id, - endorsed: false, - created_at: Timestamp::now_utc(), - updated_at: Timestamp::now_utc(), - }) - .returning(CustomEmoji::as_returning()) - .get_result::(tx) - .await?; - Ok::<_, Error>(emoji) - } - .scoped() - }) - .await?; + let emoji: CustomEmoji = with_transaction!(self.db_pool, |tx| { + let media_attachment = diesel::insert_into(media_attachments::table) + .values(NewMediaAttachment { + id: Uuid::now_v7(), + account_id: None, + content_type, + description: None, + blurhash: None, + file_path: None, + remote_url: Some(&emoji.icon.url), + }) + .returning(MediaAttachment::as_returning()) + .get_result::(tx) + .await?; + let emoji = diesel::insert_into(custom_emojis::table) + .values(CustomEmoji { + id: Uuid::now_v7(), + remote_id: emoji.id, + shortcode: name_pure.to_string(), + domain: Some(domain.to_string()), + media_attachment_id: media_attachment.id, + endorsed: false, + created_at: Timestamp::now_utc(), + updated_at: Timestamp::now_utc(), + }) + .returning(CustomEmoji::as_returning()) + .get_result::(tx) + .await?; + Ok::<_, Error>(emoji) + })?; Ok(Some(emoji)) } diff --git a/crates/kitsune-activitypub/src/fetcher/mod.rs b/crates/kitsune-activitypub/src/fetcher/mod.rs index e62ecf649..d481c078c 100644 --- a/crates/kitsune-activitypub/src/fetcher/mod.rs +++ b/crates/kitsune-activitypub/src/fetcher/mod.rs @@ -6,7 +6,6 @@ use kitsune_cache::ArcCache; use kitsune_config::language_detection::Configuration as LanguageDetectionConfig; use kitsune_core::{ consts::USER_AGENT, - error::BoxError, traits::{ fetcher::{AccountFetchOptions, PostFetchOptions}, Fetcher as FetcherTrait, Resolver, @@ -124,18 +123,15 @@ impl FetcherTrait for Fetcher { Arc::new(self.resolver.clone()) } - async fn fetch_account( - &self, - opts: AccountFetchOptions<'_>, - ) -> Result, BoxError> { + async fn fetch_account(&self, opts: AccountFetchOptions<'_>) -> eyre::Result> { Ok(self.fetch_actor(opts).await?) } - async fn fetch_emoji(&self, url: &str) -> Result, BoxError> { + async fn fetch_emoji(&self, url: &str) -> eyre::Result> { Ok(self.fetch_emoji(url).await?) } - async fn fetch_post(&self, opts: PostFetchOptions<'_>) -> Result, BoxError> { + async fn fetch_post(&self, opts: PostFetchOptions<'_>) -> eyre::Result> { Ok(self.fetch_object(opts.url, opts.call_depth).await?) } } diff --git a/crates/kitsune-activitypub/src/fetcher/object.rs b/crates/kitsune-activitypub/src/fetcher/object.rs index 3a8405db0..3b7817e4e 100644 --- a/crates/kitsune-activitypub/src/fetcher/object.rs +++ b/crates/kitsune-activitypub/src/fetcher/object.rs @@ -4,8 +4,7 @@ use autometrics::autometrics; use diesel::{ExpressionMethods, OptionalExtension, QueryDsl, SelectableHelper}; use diesel_async::RunQueryDsl; use kitsune_cache::CacheBackend; -use kitsune_db::{model::post::Post, schema::posts}; -use scoped_futures::ScopedFutureExt; +use kitsune_db::{model::post::Post, schema::posts, with_connection}; // Maximum call depth of fetching new posts. Prevents unbounded recursion. // Setting this to >=40 would cause the `fetch_infinitely_long_reply_chain` test to run into stack overflow @@ -23,20 +22,14 @@ impl Fetcher { return Ok(Some(post)); } - let post = self - .db_pool - .with_connection(|db_conn| { - async move { - posts::table - .filter(posts::url.eq(url)) - .select(Post::as_select()) - .first(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let post = with_connection!(self.db_pool, |db_conn| { + posts::table + .filter(posts::url.eq(url)) + .select(Post::as_select()) + .first(db_conn) + .await + .optional() + })?; if let Some(post) = post { self.post_cache.set(url, &post).await?; diff --git a/crates/kitsune-activitypub/src/inbox_resolver.rs b/crates/kitsune-activitypub/src/inbox_resolver.rs index acd8a362b..6fc2a7259 100644 --- a/crates/kitsune-activitypub/src/inbox_resolver.rs +++ b/crates/kitsune-activitypub/src/inbox_resolver.rs @@ -13,9 +13,8 @@ use kitsune_db::{ post::{Post, Visibility}, }, schema::{accounts, accounts_follows}, - PgPool, + with_connection, PgPool, }; -use scoped_futures::ScopedFutureExt; pub struct InboxResolver { db_pool: PgPool, @@ -32,27 +31,25 @@ impl InboxResolver { &self, account: &Account, ) -> Result> + Send + '_> { - self.db_pool - .with_connection(|db_conn| { - accounts_follows::table - .filter(accounts_follows::account_id.eq(account.id)) - .inner_join( - accounts::table.on(accounts::id.eq(accounts_follows::follower_id).and( - accounts::inbox_url - .is_not_null() - .or(accounts::shared_inbox_url.is_not_null()), - )), - ) - .distinct() - .select(coalesce_nullable( - accounts::shared_inbox_url, - accounts::inbox_url, - )) - .load_stream(db_conn) - .scoped() - }) - .await - .map_err(Error::from) + with_connection!(self.db_pool, |db_conn| { + accounts_follows::table + .filter(accounts_follows::account_id.eq(account.id)) + .inner_join( + accounts::table.on(accounts::id.eq(accounts_follows::follower_id).and( + accounts::inbox_url + .is_not_null() + .or(accounts::shared_inbox_url.is_not_null()), + )), + ) + .distinct() + .select(coalesce_nullable( + accounts::shared_inbox_url, + accounts::inbox_url, + )) + .load_stream(db_conn) + .await + }) + .map_err(Error::from) } #[instrument(skip_all, fields(post_id = %post.id))] @@ -60,35 +57,29 @@ impl InboxResolver { &self, post: &Post, ) -> Result> + Send + '_> { - let (account, mentioned_inbox_stream) = self - .db_pool - .with_connection(|db_conn| { - async move { - let account = accounts::table - .find(post.account_id) - .select(Account::as_select()) - .first(db_conn) - .await?; + let (account, mentioned_inbox_stream) = with_connection!(self.db_pool, |db_conn| { + let account = accounts::table + .find(post.account_id) + .select(Account::as_select()) + .first(db_conn) + .await?; - let mentioned_inbox_stream = Mention::belonging_to(post) - .inner_join(accounts::table) - .filter( - accounts::shared_inbox_url - .is_not_null() - .or(accounts::inbox_url.is_not_null()), - ) - .select(coalesce_nullable( - accounts::shared_inbox_url, - accounts::inbox_url, - )) - .load_stream(db_conn) - .await?; + let mentioned_inbox_stream = Mention::belonging_to(post) + .inner_join(accounts::table) + .filter( + accounts::shared_inbox_url + .is_not_null() + .or(accounts::inbox_url.is_not_null()), + ) + .select(coalesce_nullable( + accounts::shared_inbox_url, + accounts::inbox_url, + )) + .load_stream(db_conn) + .await?; - Ok::<_, Error>((account, mentioned_inbox_stream)) - } - .scoped() - }) - .await?; + Ok::<_, Error>((account, mentioned_inbox_stream)) + })?; let stream = if post.visibility == Visibility::MentionOnly { Either::Left(mentioned_inbox_stream) diff --git a/crates/kitsune-activitypub/src/lib.rs b/crates/kitsune-activitypub/src/lib.rs index 8222ddcbc..76923d0bd 100644 --- a/crates/kitsune-activitypub/src/lib.rs +++ b/crates/kitsune-activitypub/src/lib.rs @@ -20,14 +20,13 @@ use kitsune_db::{ schema::{ media_attachments, posts, posts_custom_emojis, posts_media_attachments, posts_mentions, }, - PgPool, + with_transaction, PgPool, }; use kitsune_embed::Client as EmbedClient; use kitsune_language::Language; use kitsune_search::{AnySearchBackend, SearchBackend}; use kitsune_type::ap::{object::MediaAttachment, Object, Tag, TagType}; use kitsune_util::{convert::timestamp_to_uuid, process, sanitize::CleanHtmlExt, CowBox}; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use typed_builder::TypedBuilder; @@ -288,58 +287,53 @@ pub async fn process_new_object(process_data: ProcessNewObject<'_>) -> Result(tx) + .await?; + + let attachment_ids = process_attachments(tx, &user, &object.attachment).await?; + diesel::insert_into(posts_media_attachments::table) + .values( + attachment_ids + .into_iter() + .map(|attachment_id| NewPostMediaAttachment { + post_id: new_post.id, + media_attachment_id: attachment_id, }) - .on_conflict(posts::url) - .do_update() - .set(PostConflictChangeset { - subject: object.summary.as_deref(), - content: object.content.as_str(), - }) - .returning(Post::as_returning()) - .get_result::(tx) - .await?; - - let attachment_ids = process_attachments(tx, &user, &object.attachment).await?; - diesel::insert_into(posts_media_attachments::table) - .values( - attachment_ids - .into_iter() - .map(|attachment_id| NewPostMediaAttachment { - post_id: new_post.id, - media_attachment_id: attachment_id, - }) - .collect::>(), - ) - .execute(tx) - .await?; - - handle_mentions(tx, &user, new_post.id, &object.tag).await?; - handle_custom_emojis(tx, new_post.id, fetcher, &object.tag).await?; - - Ok::<_, Error>(new_post) - } - .scoped() - }) - .await?; + .collect::>(), + ) + .execute(tx) + .await?; + + handle_mentions(tx, &user, new_post.id, &object.tag).await?; + handle_custom_emojis(tx, new_post.id, fetcher, &object.tag).await?; + + Ok::<_, Error>(new_post) + })?; if post.visibility == Visibility::Public || post.visibility == Visibility::Unlisted { search_backend.add_to_index(post.clone().into()).await?; @@ -362,51 +356,46 @@ pub async fn update_object(process_data: ProcessNewObject<'_>) -> Result { search_backend, } = preprocess_object(process_data).await?; - let post = db_pool - .with_transaction(|tx| { - async move { - let updated_post = diesel::update(posts::table) - .filter(posts::url.eq(object.id.as_str())) - .set(FullPostChangeset { - account_id: user.id, - in_reply_to_id, - reposted_post_id: None, - subject: object.summary.as_deref(), - content: object.content.as_str(), - content_source: "", - content_lang: content_lang.into(), - link_preview_url: link_preview_url.as_deref(), - is_sensitive: object.sensitive, - visibility, - is_local: false, - updated_at: Timestamp::now_utc(), + let post = with_transaction!(db_pool, |tx| { + let updated_post = diesel::update(posts::table) + .filter(posts::url.eq(object.id.as_str())) + .set(FullPostChangeset { + account_id: user.id, + in_reply_to_id, + reposted_post_id: None, + subject: object.summary.as_deref(), + content: object.content.as_str(), + content_source: "", + content_lang: content_lang.into(), + link_preview_url: link_preview_url.as_deref(), + is_sensitive: object.sensitive, + visibility, + is_local: false, + updated_at: Timestamp::now_utc(), + }) + .returning(Post::as_returning()) + .get_result::(tx) + .await?; + + let attachment_ids = process_attachments(tx, &user, &object.attachment).await?; + diesel::insert_into(posts_media_attachments::table) + .values( + attachment_ids + .into_iter() + .map(|attachment_id| NewPostMediaAttachment { + post_id: updated_post.id, + media_attachment_id: attachment_id, }) - .returning(Post::as_returning()) - .get_result::(tx) - .await?; - - let attachment_ids = process_attachments(tx, &user, &object.attachment).await?; - diesel::insert_into(posts_media_attachments::table) - .values( - attachment_ids - .into_iter() - .map(|attachment_id| NewPostMediaAttachment { - post_id: updated_post.id, - media_attachment_id: attachment_id, - }) - .collect::>(), - ) - .on_conflict_do_nothing() - .execute(tx) - .await?; - - handle_mentions(tx, &user, updated_post.id, &object.tag).await?; - - Ok::<_, Error>(updated_post) - } - .scoped() - }) - .await?; + .collect::>(), + ) + .on_conflict_do_nothing() + .execute(tx) + .await?; + + handle_mentions(tx, &user, updated_post.id, &object.tag).await?; + + Ok::<_, Error>(updated_post) + })?; if post.visibility == Visibility::Public || post.visibility == Visibility::Unlisted { search_backend.update_in_index(post.clone().into()).await?; diff --git a/crates/kitsune-activitypub/src/mapping/activity.rs b/crates/kitsune-activitypub/src/mapping/activity.rs index aae38deca..87d152964 100644 --- a/crates/kitsune-activitypub/src/mapping/activity.rs +++ b/crates/kitsune-activitypub/src/mapping/activity.rs @@ -6,10 +6,10 @@ use iso8601_timestamp::Timestamp; use kitsune_db::{ model::{account::Account, favourite::Favourite, follower::Follow, post::Post}, schema::{accounts, posts}, + with_connection, }; use kitsune_type::ap::{ap_context, Activity, ActivityType, ObjectField}; use kitsune_util::try_join; -use scoped_futures::ScopedFutureExt; use std::future::Future; pub trait IntoActivity { @@ -50,25 +50,19 @@ impl IntoActivity for Favourite { type NegateOutput = Activity; async fn into_activity(self, state: State<'_>) -> Result { - let (account_url, post_url) = state - .db_pool - .with_connection(|db_conn| { - async move { - let account_url_fut = accounts::table - .find(self.account_id) - .select(accounts::url) - .get_result::(db_conn); - - let post_url_fut = posts::table - .find(self.post_id) - .select(posts::url) - .get_result(db_conn); - - try_join!(account_url_fut, post_url_fut) - } - .scoped() - }) - .await?; + let (account_url, post_url) = with_connection!(state.db_pool, |db_conn| { + let account_url_fut = accounts::table + .find(self.account_id) + .select(accounts::url) + .get_result::(db_conn); + + let post_url_fut = posts::table + .find(self.post_id) + .select(posts::url) + .get_result(db_conn); + + try_join!(account_url_fut, post_url_fut) + })?; Ok(Activity { context: ap_context(), @@ -81,16 +75,13 @@ impl IntoActivity for Favourite { } async fn into_negate_activity(self, state: State<'_>) -> Result { - let account_url = state - .db_pool - .with_connection(|db_conn| { - accounts::table - .find(self.account_id) - .select(accounts::url) - .get_result::(db_conn) - .scoped() - }) - .await?; + let account_url = with_connection!(state.db_pool, |db_conn| { + accounts::table + .find(self.account_id) + .select(accounts::url) + .get_result::(db_conn) + .await + })?; Ok(Activity { context: ap_context(), @@ -108,25 +99,19 @@ impl IntoActivity for Follow { type NegateOutput = Activity; async fn into_activity(self, state: State<'_>) -> Result { - let (attributed_to, object) = state - .db_pool - .with_connection(|db_conn| { - async move { - let attributed_to_fut = accounts::table - .find(self.follower_id) - .select(accounts::url) - .get_result::(db_conn); - - let object_fut = accounts::table - .find(self.account_id) - .select(accounts::url) - .get_result::(db_conn); - - try_join!(attributed_to_fut, object_fut) - } - .scoped() - }) - .await?; + let (attributed_to, object) = with_connection!(state.db_pool, |db_conn| { + let attributed_to_fut = accounts::table + .find(self.follower_id) + .select(accounts::url) + .get_result::(db_conn); + + let object_fut = accounts::table + .find(self.account_id) + .select(accounts::url) + .get_result::(db_conn); + + try_join!(attributed_to_fut, object_fut) + })?; Ok(Activity { context: ap_context(), @@ -139,16 +124,13 @@ impl IntoActivity for Follow { } async fn into_negate_activity(self, state: State<'_>) -> Result { - let attributed_to = state - .db_pool - .with_connection(|db_conn| { - accounts::table - .find(self.follower_id) - .select(accounts::url) - .get_result::(db_conn) - .scoped() - }) - .await?; + let attributed_to = with_connection!(state.db_pool, |db_conn| { + accounts::table + .find(self.follower_id) + .select(accounts::url) + .get_result::(db_conn) + .await + })?; Ok(Activity { context: ap_context(), @@ -169,16 +151,13 @@ impl IntoActivity for Post { let account_url = state.service.url.user_url(self.account_id); if let Some(reposted_post_id) = self.reposted_post_id { - let reposted_post_url = state - .db_pool - .with_connection(|db_conn| { - posts::table - .find(reposted_post_id) - .select(posts::url) - .get_result(db_conn) - .scoped() - }) - .await?; + let reposted_post_url = with_connection!(state.db_pool, |db_conn| { + posts::table + .find(reposted_post_id) + .select(posts::url) + .get_result(db_conn) + .await + })?; Ok(Activity { context: ap_context(), diff --git a/crates/kitsune-activitypub/src/mapping/object.rs b/crates/kitsune-activitypub/src/mapping/object.rs index 37b9e3ff1..e099fd46a 100644 --- a/crates/kitsune-activitypub/src/mapping/object.rs +++ b/crates/kitsune-activitypub/src/mapping/object.rs @@ -12,6 +12,7 @@ use kitsune_db::{ post::Post, }, schema::{accounts, custom_emojis, media_attachments, posts, posts_custom_emojis}, + with_connection, }; use kitsune_type::ap::{ actor::{Actor, PublicKey}, @@ -22,7 +23,6 @@ use kitsune_type::ap::{ }; use kitsune_util::try_join; use mime::Mime; -use scoped_futures::ScopedFutureExt; use std::{future::Future, str::FromStr}; pub trait IntoObject { @@ -101,56 +101,51 @@ impl IntoObject for Post { return Err(Error::NotFound); } - let (account, in_reply_to, mentions, emojis, attachment_stream) = state - .db_pool - .with_connection(|db_conn| { - async { - let account_fut = accounts::table - .find(self.account_id) - .select(Account::as_select()) - .get_result(db_conn); + let (account, in_reply_to, mentions, emojis, attachment_stream) = + with_connection!(state.db_pool, |db_conn| { + let account_fut = accounts::table + .find(self.account_id) + .select(Account::as_select()) + .get_result(db_conn); - let in_reply_to_fut = - OptionFuture::from(self.in_reply_to_id.map(|in_reply_to_id| { - posts::table - .find(in_reply_to_id) - .select(posts::url) - .get_result(db_conn) - })) - .map(Option::transpose); + let in_reply_to_fut = + OptionFuture::from(self.in_reply_to_id.map(|in_reply_to_id| { + posts::table + .find(in_reply_to_id) + .select(posts::url) + .get_result(db_conn) + })) + .map(Option::transpose); - let mentions_fut = Mention::belonging_to(&self) - .inner_join(accounts::table) - .select((Mention::as_select(), Account::as_select())) - .load::<(Mention, Account)>(db_conn); + let mentions_fut = Mention::belonging_to(&self) + .inner_join(accounts::table) + .select((Mention::as_select(), Account::as_select())) + .load::<(Mention, Account)>(db_conn); - let custom_emojis_fut = custom_emojis::table - .inner_join(posts_custom_emojis::table) - .inner_join(media_attachments::table) - .filter(posts_custom_emojis::post_id.eq(self.id)) - .select(( - CustomEmoji::as_select(), - PostCustomEmoji::as_select(), - DbMediaAttachment::as_select(), - )) - .load::<(CustomEmoji, PostCustomEmoji, DbMediaAttachment)>(db_conn); + let custom_emojis_fut = custom_emojis::table + .inner_join(posts_custom_emojis::table) + .inner_join(media_attachments::table) + .filter(posts_custom_emojis::post_id.eq(self.id)) + .select(( + CustomEmoji::as_select(), + PostCustomEmoji::as_select(), + DbMediaAttachment::as_select(), + )) + .load::<(CustomEmoji, PostCustomEmoji, DbMediaAttachment)>(db_conn); - let attachment_stream_fut = PostMediaAttachment::belonging_to(&self) - .inner_join(media_attachments::table) - .select(DbMediaAttachment::as_select()) - .load_stream::(db_conn); + let attachment_stream_fut = PostMediaAttachment::belonging_to(&self) + .inner_join(media_attachments::table) + .select(DbMediaAttachment::as_select()) + .load_stream::(db_conn); - try_join!( - account_fut, - in_reply_to_fut, - mentions_fut, - custom_emojis_fut, - attachment_stream_fut - ) - } - .scoped() - }) - .await?; + try_join!( + account_fut, + in_reply_to_fut, + mentions_fut, + custom_emojis_fut, + attachment_stream_fut + ) + })?; let attachment = attachment_stream .map_err(Error::from) @@ -197,34 +192,28 @@ impl IntoObject for Account { type Output = Actor; async fn into_object(self, state: State<'_>) -> Result { - let (icon, image) = state - .db_pool - .with_connection(|db_conn| { - async move { - // These calls also probably allocate two cocnnections. ugh. - let icon_fut = OptionFuture::from(self.avatar_id.map(|avatar_id| { - media_attachments::table - .find(avatar_id) - .get_result::(db_conn) - .map_err(Error::from) - .and_then(|media_attachment| media_attachment.into_object(state)) - })) - .map(Option::transpose); + let (icon, image) = with_connection!(state.db_pool, |db_conn| { + // These calls also probably allocate two cocnnections. ugh. + let icon_fut = OptionFuture::from(self.avatar_id.map(|avatar_id| { + media_attachments::table + .find(avatar_id) + .get_result::(db_conn) + .map_err(Error::from) + .and_then(|media_attachment| media_attachment.into_object(state)) + })) + .map(Option::transpose); - let image_fut = OptionFuture::from(self.header_id.map(|header_id| { - media_attachments::table - .find(header_id) - .get_result::(db_conn) - .map_err(Error::from) - .and_then(|media_attachment| media_attachment.into_object(state)) - })) - .map(Option::transpose); + let image_fut = OptionFuture::from(self.header_id.map(|header_id| { + media_attachments::table + .find(header_id) + .get_result::(db_conn) + .map_err(Error::from) + .and_then(|media_attachment| media_attachment.into_object(state)) + })) + .map(Option::transpose); - try_join!(icon_fut, image_fut) - } - .scoped() - }) - .await?; + try_join!(icon_fut, image_fut) + })?; let user_url = state.service.url.user_url(self.id); let inbox = state.service.url.inbox_url(self.id); @@ -269,17 +258,14 @@ impl IntoObject for CustomEmoji { Some(_) => Err(Error::NotFound), }?; - let icon = state - .db_pool - .with_connection(|db_conn| { - media_attachments::table - .find(self.media_attachment_id) - .get_result::(db_conn) - .map_err(Error::from) - .and_then(|media_attachment| media_attachment.into_object(state)) - .scoped() - }) - .await?; + let icon = with_connection!(state.db_pool, |db_conn| { + media_attachments::table + .find(self.media_attachment_id) + .get_result::(db_conn) + .map_err(Error::from) + .and_then(|media_attachment| media_attachment.into_object(state)) + .await + })?; Ok(Emoji { context: ap_context(), diff --git a/crates/kitsune-activitypub/tests/fetcher/basic.rs b/crates/kitsune-activitypub/tests/fetcher/basic.rs index 25489cd11..7c3fac8f0 100644 --- a/crates/kitsune-activitypub/tests/fetcher/basic.rs +++ b/crates/kitsune-activitypub/tests/fetcher/basic.rs @@ -8,6 +8,7 @@ use kitsune_core::traits::Fetcher as _; use kitsune_db::{ model::{account::Account, media_attachment::MediaAttachment}, schema::{accounts, media_attachments}, + with_connection_panicky, }; use kitsune_federation_filter::FederationFilter; use kitsune_http_client::Client; @@ -15,7 +16,6 @@ use kitsune_search::NoopSearchService; use kitsune_test::{database_test, language_detection_config}; use kitsune_webfinger::Webfinger; use pretty_assertions::assert_eq; -use scoped_futures::ScopedFutureExt; use std::sync::Arc; use tower::service_fn; @@ -92,15 +92,14 @@ async fn fetch_emoji() { assert_eq!(emoji.shortcode, "Blobhaj"); assert_eq!(emoji.domain, Some(String::from("corteximplant.com"))); - let media_attachment = db_pool - .with_connection(|db_conn| { + let media_attachment = + with_connection_panicky!(db_pool, |db_conn| { media_attachments::table .find(emoji.media_attachment_id) .select(MediaAttachment::as_select()) .get_result::(db_conn) - .scoped() + .await }) - .await .expect("Get media attachment"); assert_eq!( @@ -149,16 +148,14 @@ async fn fetch_note() { "https://corteximplant.com/users/0x0/statuses/109501674056556919" ); - let author = db_pool - .with_connection(|db_conn| { - accounts::table - .find(note.account_id) - .select(Account::as_select()) - .get_result::(db_conn) - .scoped() - }) - .await - .expect("Get author"); + let author = with_connection_panicky!(db_pool, |db_conn| { + accounts::table + .find(note.account_id) + .select(Account::as_select()) + .get_result::(db_conn) + .await + }) + .expect("Get author"); assert_eq!(author.username, "0x0"); assert_eq!(author.url, "https://corteximplant.com/users/0x0"); diff --git a/crates/kitsune-db/src/pool.rs b/crates/kitsune-db/src/pool.rs index a16a951c6..9eaff5c74 100644 --- a/crates/kitsune-db/src/pool.rs +++ b/crates/kitsune-db/src/pool.rs @@ -10,11 +10,10 @@ macro_rules! with_connection { #[macro_export] macro_rules! with_connection_panicky { ($pool:expr, $($other:tt)*) => {{ - let result: ::std::result::Result<_, Box> = async move { - let _ = $crate::with_connection!($pool, $($other)*); - Ok(()) + let result: ::std::result::Result<_, Box> = async { + Ok($crate::with_connection!($pool, $($other)*)) }.await; - result.unwrap(); + result.unwrap() }}; } diff --git a/crates/kitsune-jobs/Cargo.toml b/crates/kitsune-jobs/Cargo.toml index 497f10327..0ceccd7ef 100644 --- a/crates/kitsune-jobs/Cargo.toml +++ b/crates/kitsune-jobs/Cargo.toml @@ -15,7 +15,6 @@ futures-util = "0.3.30" kitsune-core = { path = "../kitsune-core" } kitsune-db = { path = "../kitsune-db" } kitsune-email = { path = "../kitsune-email" } -scoped-futures = "0.1.3" serde = { version = "1.0.197", features = ["derive"] } speedy-uuid = { path = "../../lib/speedy-uuid" } tracing = "0.1.40" diff --git a/crates/kitsune-mastodon/Cargo.toml b/crates/kitsune-mastodon/Cargo.toml index ac593fa88..557ca1e05 100644 --- a/crates/kitsune-mastodon/Cargo.toml +++ b/crates/kitsune-mastodon/Cargo.toml @@ -20,7 +20,6 @@ kitsune-type = { path = "../kitsune-type" } kitsune-url = { path = "../kitsune-url" } kitsune-util = { path = "../kitsune-util" } mime = "0.3.17" -scoped-futures = "0.1.3" serde = "1.0.197" simd-json = "0.13.9" smol_str = "0.2.1" diff --git a/crates/kitsune-mastodon/src/sealed.rs b/crates/kitsune-mastodon/src/sealed.rs index eb1f5e80b..f70c7b759 100644 --- a/crates/kitsune-mastodon/src/sealed.rs +++ b/crates/kitsune-mastodon/src/sealed.rs @@ -24,7 +24,7 @@ use kitsune_db::{ accounts, accounts_follows, custom_emojis, media_attachments, notifications, posts, posts_favourites, }, - PgPool, + with_connection, PgPool, }; use kitsune_embed::Client as EmbedClient; use kitsune_embed::{embed_sdk::EmbedType, Embed}; @@ -40,7 +40,6 @@ use kitsune_type::mastodon::{ use kitsune_url::UrlService; use kitsune_util::try_join; use mime::Mime; -use scoped_futures::ScopedFutureExt; use serde::{de::DeserializeOwned, Serialize}; use smol_str::SmolStr; use speedy_uuid::Uuid; @@ -78,30 +77,25 @@ impl IntoMastodon for DbAccount { } async fn into_mastodon(self, state: MapperState<'_>) -> Result { - let (statuses_count, followers_count, following_count) = state - .db_pool - .with_connection(|db_conn| { - async { - let statuses_count_fut = posts::table - .filter(posts::account_id.eq(self.id)) - .count() - .get_result::(db_conn); - - let followers_count_fut = accounts_follows::table - .filter(accounts_follows::account_id.eq(self.id)) - .count() - .get_result::(db_conn); - - let following_count_fut = accounts_follows::table - .filter(accounts_follows::follower_id.eq(self.id)) - .count() - .get_result::(db_conn); - - try_join!(statuses_count_fut, followers_count_fut, following_count_fut) - } - .scoped() - }) - .await?; + let (statuses_count, followers_count, following_count) = + with_connection!(state.db_pool, |db_conn| { + let statuses_count_fut = posts::table + .filter(posts::account_id.eq(self.id)) + .count() + .get_result::(db_conn); + + let followers_count_fut = accounts_follows::table + .filter(accounts_follows::account_id.eq(self.id)) + .count() + .get_result::(db_conn); + + let following_count_fut = accounts_follows::table + .filter(accounts_follows::follower_id.eq(self.id)) + .count() + .get_result::(db_conn); + + try_join!(statuses_count_fut, followers_count_fut, following_count_fut) + })?; let mut acct = self.username.clone(); if !self.local { @@ -163,39 +157,34 @@ impl IntoMastodon for (&DbAccount, &DbAccount) { async fn into_mastodon(self, state: MapperState<'_>) -> Result { let (requester, target) = self; - let ((following, follow_requested), followed_by) = state - .db_pool - .with_connection(|db_conn| { - async move { - let following_requested_fut = accounts_follows::table - .filter( - accounts_follows::account_id - .eq(target.id) - .and(accounts_follows::follower_id.eq(requester.id)), - ) - .get_result::(db_conn) - .map(OptionalExtension::optional) - .map_ok(|optional_follow| { - optional_follow.map_or((false, false), |follow| { - (follow.approved_at.is_some(), follow.approved_at.is_none()) - }) - }); - - let followed_by_fut = accounts_follows::table - .filter( - accounts_follows::account_id - .eq(requester.id) - .and(accounts_follows::follower_id.eq(target.id)), - ) - .count() - .get_result::(db_conn) - .map_ok(|count| count != 0); - - try_join!(following_requested_fut, followed_by_fut) - } - .scoped() - }) - .await?; + let ((following, follow_requested), followed_by) = + with_connection!(state.db_pool, |db_conn| { + let following_requested_fut = accounts_follows::table + .filter( + accounts_follows::account_id + .eq(target.id) + .and(accounts_follows::follower_id.eq(requester.id)), + ) + .get_result::(db_conn) + .map(OptionalExtension::optional) + .map_ok(|optional_follow| { + optional_follow.map_or((false, false), |follow| { + (follow.approved_at.is_some(), follow.approved_at.is_none()) + }) + }); + + let followed_by_fut = accounts_follows::table + .filter( + accounts_follows::account_id + .eq(requester.id) + .and(accounts_follows::follower_id.eq(target.id)), + ) + .count() + .get_result::(db_conn) + .map_ok(|count| count != 0); + + try_join!(following_requested_fut, followed_by_fut) + })?; Ok(Relationship { id: target.id, @@ -223,16 +212,13 @@ impl IntoMastodon for DbMention { } async fn into_mastodon(self, state: MapperState<'_>) -> Result { - let account: DbAccount = state - .db_pool - .with_connection(|db_conn| { - accounts::table - .find(self.account_id) - .select(DbAccount::as_select()) - .get_result(db_conn) - .scoped() - }) - .await?; + let account: DbAccount = with_connection!(state.db_pool, |db_conn| { + accounts::table + .find(self.account_id) + .select(DbAccount::as_select()) + .get_result(db_conn) + .await + })?; let mut acct = account.username.clone(); if !account.local { @@ -289,29 +275,23 @@ impl IntoMastodon for (&DbAccount, DbPost) { async fn into_mastodon(self, state: MapperState<'_>) -> Result { let (account, post) = self; - let (favourited, reblogged) = state - .db_pool - .with_connection(|db_conn| { - async move { - let favourited_fut = posts_favourites::table - .filter(posts_favourites::account_id.eq(account.id)) - .filter(posts_favourites::post_id.eq(post.id)) - .count() - .get_result::(db_conn) - .map_ok(|count| count != 0); - - let reblogged_fut = posts::table - .filter(posts::account_id.eq(account.id)) - .filter(posts::reposted_post_id.eq(post.id)) - .count() - .get_result::(db_conn) - .map_ok(|count| count != 0); - - try_join!(favourited_fut, reblogged_fut) - } - .scoped() - }) - .await?; + let (favourited, reblogged) = with_connection!(state.db_pool, |db_conn| { + let favourited_fut = posts_favourites::table + .filter(posts_favourites::account_id.eq(account.id)) + .filter(posts_favourites::post_id.eq(post.id)) + .count() + .get_result::(db_conn) + .map_ok(|count| count != 0); + + let reblogged_fut = posts::table + .filter(posts::account_id.eq(account.id)) + .filter(posts::reposted_post_id.eq(post.id)) + .count() + .get_result::(db_conn) + .map_ok(|count| count != 0); + + try_join!(favourited_fut, reblogged_fut) + })?; let mut status = post.into_mastodon(state).await?; status.favourited = favourited; @@ -341,62 +321,56 @@ impl IntoMastodon for DbPost { media_attachments, mentions_stream, custom_emojis_stream, - ) = state - .db_pool - .with_connection(|db_conn| { - async { - let account_fut = accounts::table - .find(self.account_id) - .select(DbAccount::as_select()) - .get_result::(db_conn) - .map_err(Error::from) - .and_then(|db_account| db_account.into_mastodon(state)); - - let reblog_count_fut = posts::table - .filter(posts::reposted_post_id.eq(self.id)) - .count() - .get_result::(db_conn) - .map_err(Error::from); - - let favourites_count_fut = DbFavourite::belonging_to(&self) - .count() - .get_result::(db_conn) - .map_err(Error::from); - - let media_attachments_fut = DbPostMediaAttachment::belonging_to(&self) - .inner_join(media_attachments::table) - .select(DbMediaAttachment::as_select()) - .load_stream::(db_conn) + ) = with_connection!(state.db_pool, |db_conn| { + let account_fut = accounts::table + .find(self.account_id) + .select(DbAccount::as_select()) + .get_result::(db_conn) + .map_err(Error::from) + .and_then(|db_account| db_account.into_mastodon(state)); + + let reblog_count_fut = posts::table + .filter(posts::reposted_post_id.eq(self.id)) + .count() + .get_result::(db_conn) + .map_err(Error::from); + + let favourites_count_fut = DbFavourite::belonging_to(&self) + .count() + .get_result::(db_conn) + .map_err(Error::from); + + let media_attachments_fut = DbPostMediaAttachment::belonging_to(&self) + .inner_join(media_attachments::table) + .select(DbMediaAttachment::as_select()) + .load_stream::(db_conn) + .map_err(Error::from) + .and_then(|attachment_stream| { + attachment_stream .map_err(Error::from) - .and_then(|attachment_stream| { - attachment_stream - .map_err(Error::from) - .and_then(|attachment| attachment.into_mastodon(state)) - .try_collect() - }); - - let mentions_stream_fut = DbMention::belonging_to(&self) - .load_stream::(db_conn) - .map_err(Error::from); - - let custom_emojis_stream_fut = DbPostCustomEmoji::belonging_to(&self) - .inner_join(custom_emojis::table.inner_join(media_attachments::table)) - .select((DbCustomEmoji::as_select(), DbMediaAttachment::as_select())) - .load_stream::<(DbCustomEmoji, DbMediaAttachment)>(db_conn) - .map_err(Error::from); - - try_join!( - account_fut, - reblog_count_fut, - favourites_count_fut, - media_attachments_fut, - mentions_stream_fut, - custom_emojis_stream_fut - ) - } - .scoped() - }) - .await?; + .and_then(|attachment| attachment.into_mastodon(state)) + .try_collect() + }); + + let mentions_stream_fut = DbMention::belonging_to(&self) + .load_stream::(db_conn) + .map_err(Error::from); + + let custom_emojis_stream_fut = DbPostCustomEmoji::belonging_to(&self) + .inner_join(custom_emojis::table.inner_join(media_attachments::table)) + .select((DbCustomEmoji::as_select(), DbMediaAttachment::as_select())) + .load_stream::<(DbCustomEmoji, DbMediaAttachment)>(db_conn) + .map_err(Error::from); + + try_join!( + account_fut, + reblog_count_fut, + favourites_count_fut, + media_attachments_fut, + mentions_stream_fut, + custom_emojis_stream_fut + ) + })?; let link_preview = OptionFuture::from( self.link_preview_url @@ -423,30 +397,24 @@ impl IntoMastodon for DbPost { .try_collect() .await?; - let reblog = state - .db_pool - .with_connection(|db_conn| { - async { - OptionFuture::from( - OptionFuture::from(self.reposted_post_id.map(|id| { - posts::table - .find(id) - .select(DbPost::as_select()) - .get_result::(db_conn) - .map(OptionalExtension::optional) - })) - .await - .transpose()? - .flatten() - .map(|post| post.into_mastodon(state)), // This will allocate two database connections. Fuck. - ) - .await - .transpose() - } - .scoped() - }) - .await? - .map(Box::new); + let reblog = with_connection!(state.db_pool, |db_conn| { + OptionFuture::from( + OptionFuture::from(self.reposted_post_id.map(|id| { + posts::table + .find(id) + .select(DbPost::as_select()) + .get_result::(db_conn) + .map(OptionalExtension::optional) + })) + .await + .transpose()? + .flatten() + .map(|post| post.into_mastodon(state)), // This will allocate two database connections. Fuck. + ) + .await + .transpose() + })? + .map(Box::new); let language = self.content_lang.to_639_1().map(str::to_string); @@ -581,9 +549,8 @@ impl IntoMastodon for DbNotification { } async fn into_mastodon(self, state: MapperState<'_>) -> Result { - let (notification, account, status): (DbNotification, DbAccount, Option) = state - .db_pool - .with_connection(|mut db_conn| { + let (notification, account, status): (DbNotification, DbAccount, Option) = + with_connection!(state.db_pool, |db_conn| { notifications::table .filter(notifications::receiving_account_id.eq(self.receiving_account_id)) .inner_join( @@ -592,10 +559,9 @@ impl IntoMastodon for DbNotification { ) .left_outer_join(posts::table) .select(<(DbNotification, DbAccount, Option)>::as_select()) - .get_result(&mut db_conn) - .scoped() - }) - .await?; + .get_result(db_conn) + .await + })?; let status = OptionFuture::from(status.map(|status| status.into_mastodon(state))) .await diff --git a/crates/kitsune-service/Cargo.toml b/crates/kitsune-service/Cargo.toml index 1550a684b..c119fddbb 100644 --- a/crates/kitsune-service/Cargo.toml +++ b/crates/kitsune-service/Cargo.toml @@ -54,7 +54,6 @@ redis = { version = "0.25.2", default-features = false, features = [ ] } rsa = "0.9.6" rusty-s3 = { version = "0.5.0", default-features = false } -scoped-futures = "0.1.3" serde = "1.0.197" simd-json = "0.13.9" smol_str = "0.2.1" diff --git a/crates/kitsune-service/src/attachment.rs b/crates/kitsune-service/src/attachment.rs index f4ba355f2..400e1d3d9 100644 --- a/crates/kitsune-service/src/attachment.rs +++ b/crates/kitsune-service/src/attachment.rs @@ -230,12 +230,9 @@ impl AttachmentService { #[cfg(test)] mod test { - use crate::{ - attachment::{AttachmentService, Upload}, - error::Error, - }; + use crate::attachment::{AttachmentService, Upload}; use bytes::{Bytes, BytesMut}; - use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl}; + use diesel_async::{AsyncPgConnection, RunQueryDsl}; use futures_util::{future, pin_mut, stream, StreamExt}; use http::{Request, Response}; use http_body_util::Empty; @@ -250,12 +247,12 @@ mod test { media_attachment::MediaAttachment, }, schema::accounts, + with_connection_panicky, }; use kitsune_http_client::Client; use kitsune_storage::fs::Storage; use kitsune_test::database_test; use kitsune_url::UrlService; - use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use std::convert::Infallible; use tempfile::TempDir; @@ -266,12 +263,10 @@ mod test { database_test(|db_pool| async move { let client = Client::builder().service(service_fn(handle)); - let account_id = db_pool - .with_connection(|db_conn| { - async move { Ok::<_, eyre::Report>(prepare_db(db_conn).await) }.scoped() - }) - .await - .unwrap(); + let account_id = with_connection_panicky!(db_pool, |db_conn| { + Ok::<_, eyre::Report>(prepare_db(db_conn).await) + }) + .unwrap(); let temp_dir = TempDir::new().unwrap(); let storage = Storage::new(temp_dir.path().to_owned()); @@ -339,38 +334,32 @@ mod test { async fn prepare_db(db_conn: &mut AsyncPgConnection) -> Uuid { // Create a local user `@alice` - db_conn - .transaction(|tx| { - async move { - let account_id = Uuid::now_v7(); - diesel::insert_into(accounts::table) - .values(NewAccount { - id: account_id, - display_name: None, - username: "alice", - locked: false, - note: None, - local: true, - domain: "example.com", - actor_type: ActorType::Person, - url: "https://example.com/users/alice", - featured_collection_url: None, - followers_url: None, - following_url: None, - inbox_url: None, - outbox_url: None, - shared_inbox_url: None, - public_key_id: "https://example.com/users/alice#main-key", - public_key: "", - created_at: None, - }) - .execute(tx) - .await?; - Ok::<_, Error>(account_id) - } - .scope_boxed() + let account_id = Uuid::now_v7(); + diesel::insert_into(accounts::table) + .values(NewAccount { + id: account_id, + display_name: None, + username: "alice", + locked: false, + note: None, + local: true, + domain: "example.com", + actor_type: ActorType::Person, + url: "https://example.com/users/alice", + featured_collection_url: None, + followers_url: None, + following_url: None, + inbox_url: None, + outbox_url: None, + shared_inbox_url: None, + public_key_id: "https://example.com/users/alice#main-key", + public_key: "", + created_at: None, }) + .execute(db_conn) .await - .unwrap() + .unwrap(); + + account_id } } diff --git a/crates/kitsune-service/src/post/resolver.rs b/crates/kitsune-service/src/post/resolver.rs index aa6a35bcc..ecd404a1a 100644 --- a/crates/kitsune-service/src/post/resolver.rs +++ b/crates/kitsune-service/src/post/resolver.rs @@ -126,6 +126,7 @@ mod test { account::Account, custom_emoji::CustomEmoji, media_attachment::NewMediaAttachment, }, schema::{accounts, custom_emojis, media_attachments}, + with_connection_panicky, }; use kitsune_federation_filter::FederationFilter; use kitsune_http_client::Client; @@ -137,7 +138,6 @@ mod test { use kitsune_util::try_join; use kitsune_webfinger::Webfinger; use pretty_assertions::assert_eq; - use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use std::sync::Arc; use tower::service_fn; @@ -221,69 +221,63 @@ mod test { let emoji_ids = (Uuid::now_v7(), Uuid::now_v7()); let media_attachment_ids = (Uuid::now_v7(), Uuid::now_v7()); - db_pool - .with_connection(|db_conn| { - async { - let media_fut = diesel::insert_into(media_attachments::table) - .values(NewMediaAttachment { - id: media_attachment_ids.0, - content_type: "image/jpeg", - account_id: None, - description: None, - blurhash: None, - file_path: None, - remote_url: None, - }) - .execute(db_conn); - let emoji_fut = diesel::insert_into(custom_emojis::table) - .values(CustomEmoji { - id: emoji_ids.0, - shortcode: String::from("blobhaj_happy"), - domain: None, - remote_id: String::from("https://local.domain/emoji/blobhaj_happy"), - media_attachment_id: media_attachment_ids.0, - endorsed: false, - created_at: Timestamp::now_utc(), - updated_at: Timestamp::now_utc() - }) - .execute(db_conn); - try_join!(media_fut, emoji_fut) - }.scoped() - }) - .await - .expect("Failed to insert the local emoji"); - db_pool - .with_connection(|db_conn| { - async { - let media_fut = diesel::insert_into(media_attachments::table) - .values(NewMediaAttachment { - id: media_attachment_ids.1, - content_type: "image/jpeg", - account_id: None, - description: None, - blurhash: None, - file_path: None, - remote_url: Some("https://media.example.com/emojis/blobhaj.jpeg"), - }) - .execute(db_conn); - let emoji_fut = diesel::insert_into(custom_emojis::table) - .values(CustomEmoji { - id: emoji_ids.1, - shortcode: String::from("blobhaj_sad"), - domain: Some(String::from("example.com")), - remote_id: String::from("https://example.com/emojis/1"), - media_attachment_id: media_attachment_ids.1, - endorsed: false, - created_at: Timestamp::now_utc(), - updated_at: Timestamp::now_utc(), - }) - .execute(db_conn); - try_join!(media_fut, emoji_fut) - }.scoped() - }) - .await - .expect("Failed to insert the remote emoji"); + with_connection_panicky!(db_pool, |db_conn| { + let media_fut = diesel::insert_into(media_attachments::table) + .values(NewMediaAttachment { + id: media_attachment_ids.0, + content_type: "image/jpeg", + account_id: None, + description: None, + blurhash: None, + file_path: None, + remote_url: None, + }) + .execute(db_conn); + let emoji_fut = diesel::insert_into(custom_emojis::table) + .values(CustomEmoji { + id: emoji_ids.0, + shortcode: String::from("blobhaj_happy"), + domain: None, + remote_id: String::from("https://local.domain/emoji/blobhaj_happy"), + media_attachment_id: media_attachment_ids.0, + endorsed: false, + created_at: Timestamp::now_utc(), + updated_at: Timestamp::now_utc() + }) + .execute(db_conn); + + try_join!(media_fut, emoji_fut) + }) + .expect("Failed to insert the local emoji"); + + with_connection_panicky!(db_pool, |db_conn| { + let media_fut = diesel::insert_into(media_attachments::table) + .values(NewMediaAttachment { + id: media_attachment_ids.1, + content_type: "image/jpeg", + account_id: None, + description: None, + blurhash: None, + file_path: None, + remote_url: Some("https://media.example.com/emojis/blobhaj.jpeg"), + }) + .execute(db_conn); + let emoji_fut = diesel::insert_into(custom_emojis::table) + .values(CustomEmoji { + id: emoji_ids.1, + shortcode: String::from("blobhaj_sad"), + domain: Some(String::from("example.com")), + remote_id: String::from("https://example.com/emojis/1"), + media_attachment_id: media_attachment_ids.1, + endorsed: false, + created_at: Timestamp::now_utc(), + updated_at: Timestamp::now_utc(), + }) + .execute(db_conn); + try_join!(media_fut, emoji_fut) + }) + .expect("Failed to insert the remote emoji"); let post_resolver = PostResolver::builder() .account(account_service) @@ -300,16 +294,14 @@ mod test { assert_eq!(resolved.custom_emojis.len(), 2); let (account_id, _mention_text) = &resolved.mentioned_accounts[0]; - let mentioned_account = db_pool - .with_connection(|db_conn| { - accounts::table - .find(account_id) - .select(Account::as_select()) - .get_result::(db_conn) - .scoped() - }) - .await - .expect("Failed to fetch account"); + let mentioned_account = with_connection_panicky!(db_pool, |db_conn| { + accounts::table + .find(account_id) + .select(Account::as_select()) + .get_result::(db_conn) + .await + }) + .expect("Failed to fetch account"); assert_eq!(mentioned_account.username, "0x0"); assert_eq!(mentioned_account.domain, "corteximplant.com"); diff --git a/kitsune-job-runner/Cargo.toml b/kitsune-job-runner/Cargo.toml index 5fe755fd9..fde6425a2 100644 --- a/kitsune-job-runner/Cargo.toml +++ b/kitsune-job-runner/Cargo.toml @@ -14,6 +14,7 @@ eula = false [dependencies] athena = { path = "../lib/athena" } clap = { version = "4.5.4", features = ["derive", "wrap_help"] } +color-eyre = "0.6.3" just-retry = { path = "../lib/just-retry" } kitsune-config = { path = "../crates/kitsune-config" } kitsune-core = { path = "../crates/kitsune-core" } diff --git a/kitsune-job-runner/src/main.rs b/kitsune-job-runner/src/main.rs index 0bbb907f8..8912bc4af 100644 --- a/kitsune-job-runner/src/main.rs +++ b/kitsune-job-runner/src/main.rs @@ -1,4 +1,5 @@ use clap::Parser; +use color_eyre::eyre; use kitsune_config::Configuration; use kitsune_core::consts::VERSION; use kitsune_federation_filter::FederationFilter; @@ -6,7 +7,6 @@ use kitsune_job_runner::JobDispatcherState; use kitsune_service::{attachment::AttachmentService, prepare}; use kitsune_url::UrlService; use kitsune_wasm_mrf::MrfService; -use miette::IntoDiagnostic; use std::path::PathBuf; #[global_allocator] @@ -22,8 +22,8 @@ struct Args { } #[tokio::main] -async fn main() -> miette::Result<()> { - miette::set_panic_hook(); +async fn main() -> eyre::Result<()> { + color_eyre::install()?; let args = Args::parse(); let config = Configuration::load(args.config).await?; diff --git a/kitsune/Cargo.toml b/kitsune/Cargo.toml index 994fc36c0..4def0ed3f 100644 --- a/kitsune/Cargo.toml +++ b/kitsune/Cargo.toml @@ -125,6 +125,7 @@ kitsune-mastodon = { path = "../crates/kitsune-mastodon", optional = true } # "oidc" feature kitsune-oidc = { path = "../crates/kitsune-oidc", optional = true } +color-eyre = "0.6.3" [build-dependencies] camino = "1.1.6" diff --git a/kitsune/src/error.rs b/kitsune/src/error.rs index 1de24ae55..53452374f 100644 --- a/kitsune/src/error.rs +++ b/kitsune/src/error.rs @@ -4,14 +4,12 @@ use axum::{ extract::multipart::MultipartError, response::{IntoResponse, Response}, }; +use color_eyre::eyre; use diesel_async::pooled_connection::bb8; use http::StatusCode; -use kitsune_core::error::{BoxError, HttpError}; +use kitsune_core::error::HttpError; use kitsune_service::error::{Error as ServiceError, PostError}; -use std::{ - fmt::{Debug, Display}, - str::ParseBoolError, -}; +use std::{fmt::Debug, str::ParseBoolError}; use thiserror::Error; pub type Result = std::result::Result; @@ -41,7 +39,7 @@ pub enum Error { Der(#[from] der::Error), #[error(transparent)] - Fetcher(BoxError), + Fetcher(eyre::Report), #[error(transparent)] Http(#[from] http::Error), @@ -115,18 +113,6 @@ pub enum OAuth2Error { Web(#[from] oxide_auth_axum::WebError), } -impl From> for Error -where - E: Into + Debug + Display, -{ - fn from(value: kitsune_db::PoolError) -> Self { - match value { - kitsune_db::PoolError::Pool(err) => err.into(), - kitsune_db::PoolError::User(err) => err.into(), - } - } -} - impl From for Response { fn from(err: Error) -> Response { err.into_response() diff --git a/kitsune/src/http/extractor/auth.rs b/kitsune/src/http/extractor/auth.rs index 298260ba7..4f3e14e2f 100644 --- a/kitsune/src/http/extractor/auth.rs +++ b/kitsune/src/http/extractor/auth.rs @@ -13,8 +13,8 @@ use http::request::Parts; use kitsune_db::{ model::{account::Account, user::User}, schema::{accounts, oauth2_access_tokens, users}, + with_connection, }; -use scoped_futures::ScopedFutureExt; use time::OffsetDateTime; /// Mastodon-specific auth extractor alias @@ -62,16 +62,13 @@ impl FromRequestParts .filter(oauth2_access_tokens::expires_at.gt(OffsetDateTime::now_utc())); } - let (user, account) = state - .db_pool - .with_connection(|db_conn| { - user_account_query - .select(<(User, Account)>::as_select()) - .get_result(db_conn) - .scoped() - }) - .await - .map_err(Error::from)?; + let (user, account) = with_connection!(state.db_pool, |db_conn| { + user_account_query + .select(<(User, Account)>::as_select()) + .get_result(db_conn) + .await + .map_err(Error::from) + })?; Ok(Self(UserData { account, user })) } diff --git a/kitsune/src/http/extractor/signed_activity.rs b/kitsune/src/http/extractor/signed_activity.rs index 7ff5fa87e..62b2bae42 100644 --- a/kitsune/src/http/extractor/signed_activity.rs +++ b/kitsune/src/http/extractor/signed_activity.rs @@ -15,7 +15,7 @@ use diesel_async::RunQueryDsl; use http::StatusCode; use http_body_util::BodyExt; use kitsune_core::{error::HttpError, traits::fetcher::AccountFetchOptions}; -use kitsune_db::{model::account::Account, schema::accounts, PgPool}; +use kitsune_db::{model::account::Account, schema::accounts, with_connection, PgPool}; use kitsune_type::ap::Activity; use kitsune_wasm_mrf::Outcome; use scoped_futures::ScopedFutureExt; @@ -116,20 +116,18 @@ impl FromRequest for SignedActivity { async fn verify_signature( req: &http::Request<()>, - db_conn: &PgPool, + db_pool: &PgPool, expected_account: Option<&Account>, ) -> Result { let is_valid = http_signatures::cavage::easy::verify(req, |key_id| { async move { - let remote_user: Account = db_conn - .with_connection(|db_conn| { - accounts::table - .filter(accounts::public_key_id.eq(key_id)) - .select(Account::as_select()) - .first(db_conn) - .scoped() - }) - .await?; + let remote_user: Account = with_connection!(db_pool, |db_conn| { + accounts::table + .filter(accounts::public_key_id.eq(key_id)) + .select(Account::as_select()) + .first(db_conn) + .await + })?; // If we have an expected account, which we have in the case of an incoming new activity, // then we do this comparison. diff --git a/kitsune/src/http/graphql/query/instance.rs b/kitsune/src/http/graphql/query/instance.rs index 07a86d6f3..6b31d3fd0 100644 --- a/kitsune/src/http/graphql/query/instance.rs +++ b/kitsune/src/http/graphql/query/instance.rs @@ -1,7 +1,6 @@ use crate::http::graphql::{types::Instance, ContextExt}; use async_graphql::{Context, Object, Result}; use kitsune_core::consts::VERSION; -use std::convert::Into; #[derive(Default)] pub struct InstanceQuery; diff --git a/kitsune/src/http/graphql/types/account.rs b/kitsune/src/http/graphql/types/account.rs index 8f015c1d5..eb9626a46 100644 --- a/kitsune/src/http/graphql/types/account.rs +++ b/kitsune/src/http/graphql/types/account.rs @@ -12,9 +12,9 @@ use kitsune_db::{ account::Account as DbAccount, media_attachment::MediaAttachment as DbMediaAttachment, }, schema::media_attachments, + with_connection, }; use kitsune_service::account::GetPosts; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use time::OffsetDateTime; @@ -41,20 +41,15 @@ impl Account { let db_pool = &ctx.state().db_pool; if let Some(avatar_id) = self.avatar_id { - db_pool - .with_connection(|db_conn| { - async move { - media_attachments::table - .find(avatar_id) - .get_result::(db_conn) - .await - .optional() - .map(|attachment| attachment.map(Into::into)) - } - .scoped() - }) - .await - .map_err(Into::into) + with_connection!(db_pool, |db_conn| { + media_attachments::table + .find(avatar_id) + .get_result::(db_conn) + .await + .optional() + .map(|attachment| attachment.map(Into::into)) + }) + .map_err(Into::into) } else { Ok(None) } @@ -64,20 +59,15 @@ impl Account { let db_pool = &ctx.state().db_pool; if let Some(header_id) = self.header_id { - db_pool - .with_connection(|db_conn| { - async move { - media_attachments::table - .find(header_id) - .get_result::(db_conn) - .await - .optional() - .map(|attachment| attachment.map(Into::into)) - } - .scoped() - }) - .await - .map_err(Into::into) + with_connection!(db_pool, |db_conn| { + media_attachments::table + .find(header_id) + .get_result::(db_conn) + .await + .optional() + .map(|attachment| attachment.map(Into::into)) + }) + .map_err(Into::into) } else { Ok(None) } diff --git a/kitsune/src/http/graphql/types/post.rs b/kitsune/src/http/graphql/types/post.rs index 94a9dd343..d8a18d6f1 100644 --- a/kitsune/src/http/graphql/types/post.rs +++ b/kitsune/src/http/graphql/types/post.rs @@ -7,8 +7,8 @@ use futures_util::TryStreamExt; use kitsune_db::{ model::{media_attachment::MediaAttachment as DbMediaAttachment, post::Post as DbPost}, schema::{media_attachments, posts_media_attachments}, + with_connection, }; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use time::OffsetDateTime; @@ -44,22 +44,17 @@ impl Post { pub async fn attachments(&self, ctx: &Context<'_>) -> Result> { let db_pool = &ctx.state().db_pool; - let attachments = db_pool - .with_connection(|db_conn| { - async move { - media_attachments::table - .inner_join(posts_media_attachments::table) - .filter(posts_media_attachments::post_id.eq(self.id)) - .select(DbMediaAttachment::as_select()) - .load_stream(db_conn) - .await? - .map_ok(Into::into) - .try_collect() - .await - } - .scoped() - }) - .await?; + let attachments = with_connection!(db_pool, |db_conn| { + media_attachments::table + .inner_join(posts_media_attachments::table) + .filter(posts_media_attachments::post_id.eq(self.id)) + .select(DbMediaAttachment::as_select()) + .load_stream(db_conn) + .await? + .map_ok(Into::into) + .try_collect() + .await + })?; Ok(attachments) } diff --git a/kitsune/src/http/graphql/types/user.rs b/kitsune/src/http/graphql/types/user.rs index a7952fe25..9dc763ac0 100644 --- a/kitsune/src/http/graphql/types/user.rs +++ b/kitsune/src/http/graphql/types/user.rs @@ -6,8 +6,8 @@ use diesel_async::RunQueryDsl; use kitsune_db::{ model::{account::Account as DbAccount, user::User as DbUser}, schema::{accounts, users}, + with_connection, }; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use time::OffsetDateTime; @@ -27,21 +27,16 @@ pub struct User { impl User { pub async fn account(&self, ctx: &Context<'_>) -> Result { let db_pool = &ctx.state().db_pool; - db_pool - .with_connection(|db_conn| { - async move { - users::table - .find(self.id) - .inner_join(accounts::table) - .select(DbAccount::as_select()) - .get_result::(db_conn) - .await - .map(Into::into) - } - .scoped() - }) - .await - .map_err(Into::into) + with_connection!(db_pool, |db_conn| { + users::table + .find(self.id) + .inner_join(accounts::table) + .select(DbAccount::as_select()) + .get_result::(db_conn) + .await + .map(Into::into) + }) + .map_err(Into::into) } } diff --git a/kitsune/src/http/handler/mastodon/api/v1/accounts/relationships.rs b/kitsune/src/http/handler/mastodon/api/v1/accounts/relationships.rs index 4c0625e2d..b0f4ac7bb 100644 --- a/kitsune/src/http/handler/mastodon/api/v1/accounts/relationships.rs +++ b/kitsune/src/http/handler/mastodon/api/v1/accounts/relationships.rs @@ -7,10 +7,9 @@ use axum_extra::extract::Query; use diesel::{ExpressionMethods, QueryDsl, SelectableHelper}; use diesel_async::RunQueryDsl; use futures_util::StreamExt; -use kitsune_db::{model::account::Account, schema::accounts, PgPool}; +use kitsune_db::{model::account::Account, schema::accounts, with_connection, PgPool}; use kitsune_mastodon::MastodonMapper; use kitsune_type::mastodon::relationship::Relationship; -use scoped_futures::ScopedFutureExt; use serde::Deserialize; use speedy_uuid::Uuid; use utoipa::IntoParams; @@ -39,15 +38,13 @@ pub async fn get( State(mastodon_mapper): State, Query(query): Query, ) -> Result>> { - let mut account_stream = db_pool - .with_connection(|db_conn| { - accounts::table - .filter(accounts::id.eq_any(&query.id)) - .select(Account::as_select()) - .load_stream::(db_conn) - .scoped() - }) - .await?; + let mut account_stream = with_connection!(db_pool, |db_conn| { + accounts::table + .filter(accounts::id.eq_any(&query.id)) + .select(Account::as_select()) + .load_stream::(db_conn) + .await + })?; let mut relationships = Vec::with_capacity(query.id.len()); while let Some(account) = account_stream.next().await.transpose()? { diff --git a/kitsune/src/http/handler/nodeinfo/two_one.rs b/kitsune/src/http/handler/nodeinfo/two_one.rs index 5be023b7c..9983c2450 100644 --- a/kitsune/src/http/handler/nodeinfo/two_one.rs +++ b/kitsune/src/http/handler/nodeinfo/two_one.rs @@ -5,14 +5,13 @@ use diesel_async::RunQueryDsl; use kitsune_core::consts::VERSION; use kitsune_db::{ schema::{posts, users}, - PgPool, + with_connection, PgPool, }; use kitsune_service::user::UserService; use kitsune_type::nodeinfo::two_one::{ Protocol, Services, Software, TwoOne, Usage, UsageUsers, Version, }; use kitsune_util::try_join; -use scoped_futures::ScopedFutureExt; use simd_json::{OwnedValue, ValueBuilder}; #[debug_handler(state = crate::state::Zustand)] @@ -27,20 +26,15 @@ async fn get( State(db_pool): State, State(user_service): State, ) -> Result> { - let (total, local_posts) = db_pool - .with_connection(|db_conn| { - async move { - let total_fut = users::table.count().get_result::(db_conn); - let local_posts_fut = posts::table - .filter(posts::is_local.eq(true)) - .count() - .get_result::(db_conn); + let (total, local_posts) = with_connection!(db_pool, |db_conn| { + let total_fut = users::table.count().get_result::(db_conn); + let local_posts_fut = posts::table + .filter(posts::is_local.eq(true)) + .count() + .get_result::(db_conn); - try_join!(total_fut, local_posts_fut) - } - .scoped() - }) - .await?; + try_join!(total_fut, local_posts_fut) + })?; Ok(Json(TwoOne { version: Version::TwoOne, diff --git a/kitsune/src/http/handler/oauth/authorize.rs b/kitsune/src/http/handler/oauth/authorize.rs index 5df998605..78ae3b376 100644 --- a/kitsune/src/http/handler/oauth/authorize.rs +++ b/kitsune/src/http/handler/oauth/authorize.rs @@ -19,10 +19,10 @@ use axum_flash::{Flash, IncomingFlashes}; use cursiv::CsrfHandle; use diesel::{ExpressionMethods, OptionalExtension, QueryDsl}; use diesel_async::RunQueryDsl; +use kitsune_db::with_connection; use kitsune_db::{model::user::User, schema::users, PgPool}; use oxide_auth_async::endpoint::authorization::AuthorizationFlow; use oxide_auth_axum::{OAuthRequest, OAuthResponse}; -use scoped_futures::ScopedFutureExt; use serde::Deserialize; use speedy_uuid::Uuid; @@ -69,15 +69,13 @@ pub async fn get( ) -> Result> { #[cfg(feature = "oidc")] if let Some(oidc_service) = oidc_service { - let application = db_pool - .with_connection(|db_conn| { - oauth2_applications::table - .find(query.client_id) - .filter(oauth2_applications::redirect_uri.eq(query.redirect_uri)) - .get_result::(db_conn) - .scoped() - }) - .await?; + let application = with_connection!(db_pool, |db_conn| { + oauth2_applications::table + .find(query.client_id) + .filter(oauth2_applications::redirect_uri.eq(query.redirect_uri)) + .get_result::(db_conn) + .await + })?; let auth_url = oidc_service .authorisation_url(application.id, query.scope, query.state) @@ -88,10 +86,9 @@ pub async fn get( let authenticated_user = if let Some(user_id) = cookies.get("user_id") { let id = user_id.value().parse::()?; - - db_pool - .with_connection(|db_conn| users::table.find(id).get_result(db_conn).scoped()) - .await? + with_connection!(db_pool, |db_conn| { + users::table.find(id).get_result(db_conn).await + })? } else { return Ok(Either3::E2(LoginPage { flash_messages })); }; @@ -123,18 +120,13 @@ pub async fn post( original_url.path() }; - let user = db_pool - .with_connection(|db_conn| { - async move { - users::table - .filter(users::username.eq(form.username)) - .first::(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let user = with_connection!(db_pool, |db_conn| { + users::table + .filter(users::username.eq(form.username)) + .first::(db_conn) + .await + .optional() + })?; let Some(user) = user else { return Ok(Either::E2(( diff --git a/kitsune/src/http/handler/oidc/callback.rs b/kitsune/src/http/handler/oidc/callback.rs index e4ab4665a..0f92faaa7 100644 --- a/kitsune/src/http/handler/oidc/callback.rs +++ b/kitsune/src/http/handler/oidc/callback.rs @@ -11,11 +11,10 @@ use diesel_async::RunQueryDsl; use kitsune_core::error::HttpError; use kitsune_db::{ schema::{oauth2_applications, users}, - PgPool, + with_connection, PgPool, }; use kitsune_oidc::OidcService; use kitsune_service::user::{Register, UserService}; -use scoped_futures::ScopedFutureExt; use serde::Deserialize; #[derive(Debug, Deserialize)] @@ -36,18 +35,13 @@ pub async fn get( }; let user_info = oidc_service.get_user_info(query.state, query.code).await?; - let user = db_pool - .with_connection(|db_conn| { - async { - users::table - .filter(users::oidc_id.eq(&user_info.subject)) - .get_result(db_conn) - .await - .optional() - } - .scoped() - }) - .await?; + let user = with_connection!(db_pool, |db_conn| { + users::table + .filter(users::oidc_id.eq(&user_info.subject)) + .get_result(db_conn) + .await + .optional() + })?; let user = if let Some(user) = user { user @@ -62,14 +56,12 @@ pub async fn get( user_service.register(register).await? }; - let application = db_pool - .with_connection(|db_conn| { - oauth2_applications::table - .find(user_info.oauth2.application_id) - .get_result(db_conn) - .scoped() - }) - .await?; + let application = with_connection!(db_pool, |db_conn| { + oauth2_applications::table + .find(user_info.oauth2.application_id) + .get_result(db_conn) + .await + })?; let authorisation_code = AuthorisationCode::builder() .application(application) diff --git a/kitsune/src/http/handler/users/followers.rs b/kitsune/src/http/handler/users/followers.rs index 14c113cd8..a1732373f 100644 --- a/kitsune/src/http/handler/users/followers.rs +++ b/kitsune/src/http/handler/users/followers.rs @@ -2,13 +2,15 @@ use crate::{error::Result, http::responder::ActivityPubJson, state::Zustand}; use axum::extract::{OriginalUri, Path, State}; use diesel::{BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl}; use diesel_async::RunQueryDsl; -use kitsune_db::schema::{accounts, accounts_follows}; +use kitsune_db::{ + schema::{accounts, accounts_follows}, + with_connection, +}; use kitsune_type::ap::{ ap_context, collection::{Collection, CollectionType}, }; use kitsune_url::UrlService; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; pub async fn get( @@ -17,22 +19,19 @@ pub async fn get( OriginalUri(original_uri): OriginalUri, Path(account_id): Path, ) -> Result> { - let follower_count = state - .db_pool - .with_connection(|db_conn| { - accounts_follows::table - .inner_join( - accounts::table.on(accounts_follows::account_id - .eq(accounts::id) - .and(accounts_follows::approved_at.is_not_null()) - .and(accounts::id.eq(account_id)) - .and(accounts::local.eq(true))), - ) - .count() - .get_result::(db_conn) - .scoped() - }) - .await?; + let follower_count = with_connection!(state.db_pool, |db_conn| { + accounts_follows::table + .inner_join( + accounts::table.on(accounts_follows::account_id + .eq(accounts::id) + .and(accounts_follows::approved_at.is_not_null()) + .and(accounts::id.eq(account_id)) + .and(accounts::local.eq(true))), + ) + .count() + .get_result::(db_conn) + .await + })?; let mut id = url_service.base_url(); id.push_str(original_uri.path()); diff --git a/kitsune/src/http/handler/users/following.rs b/kitsune/src/http/handler/users/following.rs index 30747a23d..478c01390 100644 --- a/kitsune/src/http/handler/users/following.rs +++ b/kitsune/src/http/handler/users/following.rs @@ -2,13 +2,15 @@ use crate::{error::Result, http::responder::ActivityPubJson, state::Zustand}; use axum::extract::{OriginalUri, Path, State}; use diesel::{BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl}; use diesel_async::RunQueryDsl; -use kitsune_db::schema::{accounts, accounts_follows}; +use kitsune_db::{ + schema::{accounts, accounts_follows}, + with_connection, +}; use kitsune_type::ap::{ ap_context, collection::{Collection, CollectionType}, }; use kitsune_url::UrlService; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; pub async fn get( @@ -17,22 +19,19 @@ pub async fn get( OriginalUri(original_uri): OriginalUri, Path(account_id): Path, ) -> Result> { - let following_count = state - .db_pool - .with_connection(|db_conn| { - accounts_follows::table - .inner_join( - accounts::table.on(accounts_follows::follower_id - .eq(accounts::id) - .and(accounts_follows::approved_at.is_not_null()) - .and(accounts::id.eq(account_id)) - .and(accounts::local.eq(true))), - ) - .count() - .get_result::(db_conn) - .scoped() - }) - .await?; + let following_count = with_connection!(state.db_pool, |db_conn| { + accounts_follows::table + .inner_join( + accounts::table.on(accounts_follows::follower_id + .eq(accounts::id) + .and(accounts_follows::approved_at.is_not_null()) + .and(accounts::id.eq(account_id)) + .and(accounts::local.eq(true))), + ) + .count() + .get_result::(db_conn) + .await + })?; let id = format!("{}{}", url_service.base_url(), original_uri.path()); Ok(ActivityPubJson(Collection { diff --git a/kitsune/src/http/handler/users/inbox.rs b/kitsune/src/http/handler/users/inbox.rs index a05f52293..3e383080a 100644 --- a/kitsune/src/http/handler/users/inbox.rs +++ b/kitsune/src/http/handler/users/inbox.rs @@ -25,28 +25,23 @@ use kitsune_db::{ }, post_permission_check::{PermissionCheck, PostPermissionCheckExt}, schema::{accounts_follows, accounts_preferences, notifications, posts, posts_favourites}, + with_connection, }; use kitsune_federation_filter::FederationFilter; use kitsune_jobs::deliver::accept::DeliverAccept; use kitsune_service::job::Enqueue; use kitsune_type::ap::{Activity, ActivityType}; use kitsune_util::try_join; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use std::ops::Not; async fn accept_activity(state: &Zustand, activity: Activity) -> Result<()> { - state - .db_pool - .with_connection(|db_conn| { - diesel::update( - accounts_follows::table.filter(accounts_follows::url.eq(activity.object())), - ) + with_connection!(state.db_pool, |db_conn| { + diesel::update(accounts_follows::table.filter(accounts_follows::url.eq(activity.object()))) .set(accounts_follows::approved_at.eq(Timestamp::now_utc())) .execute(db_conn) - .scoped() - }) - .await?; + .await + })?; Ok(()) } @@ -61,30 +56,27 @@ async fn announce_activity(state: &Zustand, author: Account, activity: Activity) return Err(HttpError::BadRequest.into()); }; - state - .db_pool - .with_connection(|db_conn| { - diesel::insert_into(posts::table) - .values(NewPost { - id: Uuid::now_v7(), - account_id: author.id, - in_reply_to_id: None, - reposted_post_id: Some(reposted_post.id), - is_sensitive: false, - subject: None, - content: "", - content_source: "", - content_lang: kitsune_language::Language::Eng.into(), - link_preview_url: None, - visibility: reposted_post.visibility, - is_local: false, - url: activity.id.as_str(), - created_at: None, - }) - .execute(db_conn) - .scoped() - }) - .await?; + with_connection!(state.db_pool, |db_conn| { + diesel::insert_into(posts::table) + .values(NewPost { + id: Uuid::now_v7(), + account_id: author.id, + in_reply_to_id: None, + reposted_post_id: Some(reposted_post.id), + is_sensitive: false, + subject: None, + content: "", + content_source: "", + content_lang: kitsune_language::Language::Eng.into(), + link_preview_url: None, + visibility: reposted_post.visibility, + is_local: false, + url: activity.id.as_str(), + created_at: None, + }) + .execute(db_conn) + .await + })?; Ok(()) } @@ -117,26 +109,20 @@ async fn create_activity(state: &Zustand, author: Account, activity: Activity) - } async fn delete_activity(state: &Zustand, author: Account, activity: Activity) -> Result<()> { - let post_id = state - .db_pool - .with_connection(|db_conn| { - async move { - let post_id = posts::table - .filter(posts::account_id.eq(author.id)) - .filter(posts::url.eq(activity.object())) - .select(posts::id) - .get_result(db_conn) - .await?; - - diesel::delete(posts::table.find(post_id)) - .execute(db_conn) - .await?; + let post_id = with_connection!(state.db_pool, |db_conn| { + let post_id = posts::table + .filter(posts::account_id.eq(author.id)) + .filter(posts::url.eq(activity.object())) + .select(posts::id) + .get_result(db_conn) + .await?; - Ok::<_, Error>(post_id) - } - .scoped() - }) - .await?; + diesel::delete(posts::table.find(post_id)) + .execute(db_conn) + .await?; + + Ok::<_, Error>(post_id) + })?; state .event_emitter @@ -163,36 +149,31 @@ async fn follow_activity(state: &Zustand, author: Account, activity: Activity) - let approved_at = followed_user.locked.not().then(Timestamp::now_utc); - let follow_id = state - .db_pool - .with_connection(|db_conn| { - diesel::insert_into(accounts_follows::table) - .values(NewFollow { - id: Uuid::now_v7(), - account_id: followed_user.id, - follower_id: author.id, - approved_at, - url: activity.id.as_str(), - notify: false, - created_at: Some(activity.published), - }) - .returning(accounts_follows::id) - .get_result(db_conn) - .scoped() - }) - .await?; + let follow_id = with_connection!(state.db_pool, |db_conn| { + diesel::insert_into(accounts_follows::table) + .values(NewFollow { + id: Uuid::now_v7(), + account_id: followed_user.id, + follower_id: author.id, + approved_at, + url: activity.id.as_str(), + notify: false, + created_at: Some(activity.published), + }) + .returning(accounts_follows::id) + .get_result(db_conn) + .await + })?; if followed_user.local { - let preferences = state - .db_pool - .with_connection(|mut db_conn| { - accounts_preferences::table - .find(followed_user.id) - .select(Preferences::as_select()) - .get_result(&mut db_conn) - .scoped() - }) - .await?; + let preferences = with_connection!(state.db_pool, |db_conn| { + accounts_preferences::table + .find(followed_user.id) + .select(Preferences::as_select()) + .get_result(db_conn) + .await + })?; + if (preferences.notify_on_follow && !followed_user.locked) || (preferences.notify_on_follow_request && followed_user.locked) { @@ -205,16 +186,14 @@ async fn follow_activity(state: &Zustand, author: Account, activity: Activity) - .receiving_account_id(followed_user.id) .follow(author.id) }; - state - .db_pool - .with_connection(|mut db_conn| { - diesel::insert_into(notifications::table) - .values(notification) - .on_conflict_do_nothing() - .execute(&mut db_conn) - .scoped() - }) - .await?; + + with_connection!(state.db_pool, |db_conn| { + diesel::insert_into(notifications::table) + .values(notification) + .on_conflict_do_nothing() + .execute(db_conn) + .await + })?; } state .service @@ -231,94 +210,79 @@ async fn like_activity(state: &Zustand, author: Account, activity: Activity) -> .fetching_account_id(Some(author.id)) .build(); - state - .db_pool - .with_connection(|db_conn| { - async move { - let post = posts::table - .filter(posts::url.eq(activity.object())) - .add_post_permission_check(permission_check) - .select(Post::as_select()) - .get_result::(db_conn) - .await?; - - diesel::insert_into(posts_favourites::table) - .values(NewFavourite { - id: Uuid::now_v7(), - account_id: author.id, - post_id: post.id, - url: activity.id, - created_at: Some(Timestamp::now_utc()), - }) - .execute(db_conn) - .await?; + with_connection!(state.db_pool, |db_conn| { + let post = posts::table + .filter(posts::url.eq(activity.object())) + .add_post_permission_check(permission_check) + .select(Post::as_select()) + .get_result::(db_conn) + .await?; - Ok::<_, Error>(()) - } - .scoped() - }) - .await?; + diesel::insert_into(posts_favourites::table) + .values(NewFavourite { + id: Uuid::now_v7(), + account_id: author.id, + post_id: post.id, + url: activity.id, + created_at: Some(Timestamp::now_utc()), + }) + .execute(db_conn) + .await?; + + Ok::<_, Error>(()) + })?; Ok(()) } async fn reject_activity(state: &Zustand, author: Account, activity: Activity) -> Result<()> { - state - .db_pool - .with_connection(|db_conn| { - diesel::delete( - accounts_follows::table.filter( - accounts_follows::account_id - .eq(author.id) - .and(accounts_follows::url.eq(activity.object())), - ), - ) - .execute(db_conn) - .scoped() - }) - .await?; + with_connection!(state.db_pool, |db_conn| { + diesel::delete( + accounts_follows::table.filter( + accounts_follows::account_id + .eq(author.id) + .and(accounts_follows::url.eq(activity.object())), + ), + ) + .execute(db_conn) + .await + })?; Ok(()) } async fn undo_activity(state: &Zustand, author: Account, activity: Activity) -> Result<()> { - state - .db_pool - .with_connection(|db_conn| { - async move { - // An undo activity can apply for likes and follows and announces - let favourite_delete_fut = diesel::delete( - posts_favourites::table.filter( - posts_favourites::account_id - .eq(author.id) - .and(posts_favourites::url.eq(activity.object())), - ), - ) - .execute(db_conn); - - let follow_delete_fut = diesel::delete( - accounts_follows::table.filter( - accounts_follows::follower_id - .eq(author.id) - .and(accounts_follows::url.eq(activity.object())), - ), - ) - .execute(db_conn); - - let repost_delete_fut = diesel::delete( - posts::table.filter( - posts::url - .eq(activity.object()) - .and(posts::account_id.eq(author.id)), - ), - ) - .execute(db_conn); - - try_join!(favourite_delete_fut, follow_delete_fut, repost_delete_fut) - } - .scoped() - }) - .await?; + with_connection!(state.db_pool, |db_conn| { + // An undo activity can apply for likes and follows and announces + let favourite_delete_fut = diesel::delete( + posts_favourites::table.filter( + posts_favourites::account_id + .eq(author.id) + .and(posts_favourites::url.eq(activity.object())), + ), + ) + .execute(db_conn); + + let follow_delete_fut = diesel::delete( + accounts_follows::table.filter( + accounts_follows::follower_id + .eq(author.id) + .and(accounts_follows::url.eq(activity.object())), + ), + ) + .execute(db_conn); + + let repost_delete_fut = diesel::delete( + posts::table.filter( + posts::url + .eq(activity.object()) + .and(posts::account_id.eq(author.id)), + ), + ) + .execute(db_conn); + + try_join!(favourite_delete_fut, follow_delete_fut, repost_delete_fut) + })?; Ok(()) } diff --git a/kitsune/src/http/handler/users/outbox.rs b/kitsune/src/http/handler/users/outbox.rs index 304a75ea8..ab8a713af 100644 --- a/kitsune/src/http/handler/users/outbox.rs +++ b/kitsune/src/http/handler/users/outbox.rs @@ -8,6 +8,7 @@ use kitsune_db::{ model::{account::Account, post::Post}, post_permission_check::{PermissionCheck, PostPermissionCheckExt}, schema::accounts, + with_connection, }; use kitsune_service::account::GetPosts; use kitsune_type::ap::{ @@ -16,7 +17,6 @@ use kitsune_type::ap::{ Activity, }; use kitsune_url::UrlService; -use scoped_futures::ScopedFutureExt; use serde::{Deserialize, Serialize}; use speedy_uuid::Uuid; @@ -37,19 +37,16 @@ pub async fn get( Path(account_id): Path, Query(query): Query, ) -> Result>, ActivityPubJson>> { - let account = state - .db_pool - .with_connection(|db_conn| { - use diesel_async::RunQueryDsl; + let account = with_connection!(state.db_pool, |db_conn| { + use diesel_async::RunQueryDsl; - accounts::table - .find(account_id) - .filter(accounts::local.eq(true)) - .select(Account::as_select()) - .get_result::(db_conn) - .scoped() - }) - .await?; + accounts::table + .find(account_id) + .filter(accounts::local.eq(true)) + .select(Account::as_select()) + .get_result::(db_conn) + .await + })?; let base_url = format!("{}{}", url_service.base_url(), original_uri.path()); @@ -93,18 +90,15 @@ pub async fn get( ordered_items, }))) } else { - let public_post_count = state - .db_pool - .with_connection(|db_conn| { - use diesel_async::RunQueryDsl; + let public_post_count = with_connection!(state.db_pool, |db_conn| { + use diesel_async::RunQueryDsl; - Post::belonging_to(&account) - .add_post_permission_check(PermissionCheck::default()) - .count() - .get_result::(db_conn) - .scoped() - }) - .await?; + Post::belonging_to(&account) + .add_post_permission_check(PermissionCheck::default()) + .count() + .get_result::(db_conn) + .await + })?; let first = format!("{base_url}?page=true"); let last = format!("{base_url}?page=true&min_id={}", Uuid::nil()); diff --git a/kitsune/src/http/handler/well_known/webfinger.rs b/kitsune/src/http/handler/well_known/webfinger.rs index 9362a3d70..01a5f0049 100644 --- a/kitsune/src/http/handler/well_known/webfinger.rs +++ b/kitsune/src/http/handler/well_known/webfinger.rs @@ -83,7 +83,7 @@ mod tests { use kitsune_db::{ model::account::{ActorType, NewAccount}, schema::accounts, - PgPool, + with_connection_panicky, PgPool, }; use kitsune_federation_filter::FederationFilter; use kitsune_http_client::Client; @@ -169,12 +169,8 @@ mod tests { async fn basic() { database_test(|db_pool| { redis_test(|redis_pool| async move { - let account_id = db_pool - .with_connection(|db_conn| { - async move { Ok::<_, eyre::Report>(prepare_db(db_conn).await) }.scoped() - }) - .await - .unwrap(); + let account_id = + with_connection_panicky!(db_pool, |db_conn| { prepare_db(db_conn).await }); let account_url = format!("https://example.com/users/{account_id}"); let url_service = UrlService::builder() @@ -236,16 +232,9 @@ mod tests { async fn custom_domain() { database_test(|db_pool| { redis_test(|redis_pool| async move { - db_pool - .with_connection(|db_conn| { - async move { - prepare_db(db_conn).await; - Ok::<_, eyre::Report>(()) - } - .scoped() - }) - .await - .unwrap(); + with_connection_panicky!(db_pool, |db_conn| { + prepare_db(db_conn).await; + }); let url_service = UrlService::builder() .scheme("https") diff --git a/kitsune/src/http/mod.rs b/kitsune/src/http/mod.rs index c6b1da082..7682d256a 100644 --- a/kitsune/src/http/mod.rs +++ b/kitsune/src/http/mod.rs @@ -6,9 +6,9 @@ use self::{ }; use crate::state::Zustand; use axum::{extract::DefaultBodyLimit, Router}; +use color_eyre::eyre::{self, Context}; use cursiv::CsrfLayer; use kitsune_config::server; -use miette::{Context, IntoDiagnostic}; use std::time::Duration; use tokio::net::TcpListener; use tower_http::{ @@ -37,7 +37,7 @@ pub mod extractor; pub fn create_router( state: Zustand, server_config: &server::Configuration, -) -> miette::Result { +) -> eyre::Result { let frontend_dir = &server_config.frontend_dir; let frontend_index_path = { let mut tmp = frontend_dir.to_string(); @@ -111,7 +111,7 @@ pub async fn run( state: Zustand, server_config: server::Configuration, shutdown_signal: crate::signal::Receiver, -) -> miette::Result<()> { +) -> eyre::Result<()> { let router = create_router(state, &server_config)?; let listener = TcpListener::bind(("0.0.0.0", server_config.port)).await?; diff --git a/kitsune/src/lib.rs b/kitsune/src/lib.rs index 55b180744..1da99707d 100644 --- a/kitsune/src/lib.rs +++ b/kitsune/src/lib.rs @@ -16,6 +16,7 @@ use self::{ state::{EventEmitter, Service, SessionConfig, Zustand, ZustandInner}, }; use athena::JobQueue; +use color_eyre::eyre; use kitsune_config::Configuration; use kitsune_db::PgPool; use kitsune_email::MailingService; @@ -50,7 +51,7 @@ pub async fn initialise_state( config: &Configuration, db_pool: PgPool, job_queue: JobQueue, -) -> miette::Result { +) -> eyre::Result { let messaging_hub = prepare::messaging(&config.messaging).await?; let status_event_emitter = messaging_hub.emitter("event.status".into()); diff --git a/kitsune/src/oauth2/authorizer.rs b/kitsune/src/oauth2/authorizer.rs index 26c50858b..f377760d5 100644 --- a/kitsune/src/oauth2/authorizer.rs +++ b/kitsune/src/oauth2/authorizer.rs @@ -5,12 +5,11 @@ use diesel_async::RunQueryDsl; use kitsune_db::{ model::oauth2, schema::{oauth2_applications, oauth2_authorization_codes}, - PgPool, + with_connection, PgPool, }; use kitsune_util::generate_secret; use oxide_auth::primitives::grant::{Extensions, Grant}; use oxide_auth_async::primitives::Authorizer; -use scoped_futures::ScopedFutureExt; #[derive(Clone)] pub struct OAuthAuthorizer { @@ -26,40 +25,32 @@ impl Authorizer for OAuthAuthorizer { let secret = generate_secret(); let expires_at = chrono_to_timestamp(grant.until); - self.db_pool - .with_connection(|db_conn| { - diesel::insert_into(oauth2_authorization_codes::table) - .values(oauth2::NewAuthorizationCode { - code: secret.as_str(), - application_id, - user_id, - scopes: scopes.as_str(), - expires_at, - }) - .returning(oauth2_authorization_codes::code) - .get_result(db_conn) - .scoped() - }) - .await - .map_err(|_| ()) + with_connection!(self.db_pool, |db_conn| { + diesel::insert_into(oauth2_authorization_codes::table) + .values(oauth2::NewAuthorizationCode { + code: secret.as_str(), + application_id, + user_id, + scopes: scopes.as_str(), + expires_at, + }) + .returning(oauth2_authorization_codes::code) + .get_result(db_conn) + .await + }) + .map_err(|_| ()) } async fn extract(&mut self, authorization_code: &str) -> Result, ()> { - let oauth_data = self - .db_pool - .with_connection(|db_conn| { - async move { - oauth2_authorization_codes::table - .find(authorization_code) - .inner_join(oauth2_applications::table) - .first::<(oauth2::AuthorizationCode, oauth2::Application)>(db_conn) - .await - .optional() - } - .scoped() - }) - .await - .map_err(|_| ())?; + let oauth_data = with_connection!(self.db_pool, |db_conn| { + oauth2_authorization_codes::table + .find(authorization_code) + .inner_join(oauth2_applications::table) + .first::<(oauth2::AuthorizationCode, oauth2::Application)>(db_conn) + .await + .optional() + }) + .map_err(|_| ())?; let oauth_data = oauth_data.map(|(code, app)| { let scope = app.scopes.parse().unwrap(); diff --git a/kitsune/src/oauth2/issuer.rs b/kitsune/src/oauth2/issuer.rs index 44ee138f6..02eeda4d8 100644 --- a/kitsune/src/oauth2/issuer.rs +++ b/kitsune/src/oauth2/issuer.rs @@ -6,7 +6,7 @@ use diesel_async::RunQueryDsl; use kitsune_db::{ model::oauth2, schema::{oauth2_access_tokens, oauth2_applications, oauth2_refresh_tokens}, - PgPool, + with_connection, with_transaction, PgPool, }; use kitsune_util::generate_secret; use oxide_auth::primitives::{ @@ -15,7 +15,6 @@ use oxide_auth::primitives::{ prelude::IssuedToken, }; use oxide_auth_async::primitives::Issuer; -use scoped_futures::ScopedFutureExt; #[derive(Clone)] pub struct OAuthIssuer { @@ -30,38 +29,32 @@ impl Issuer for OAuthIssuer { let scopes = grant.scope.to_string(); let expires_at = chrono_to_timestamp(grant.until); - let (access_token, refresh_token) = self - .db_pool - .with_transaction(|tx| { - async move { - let access_token = diesel::insert_into(oauth2_access_tokens::table) - .values(oauth2::NewAccessToken { - token: generate_secret().as_str(), - user_id: Some(user_id), - application_id: Some(application_id), - scopes: scopes.as_str(), - expires_at, - }) - .returning(oauth2::AccessToken::as_returning()) - .get_result::(tx) - .await?; - - let refresh_token = diesel::insert_into(oauth2_refresh_tokens::table) - .values(oauth2::NewRefreshToken { - token: generate_secret().as_str(), - access_token: access_token.token.as_str(), - application_id, - }) - .returning(oauth2::RefreshToken::as_returning()) - .get_result::(tx) - .await?; - - Ok::<_, Error>((access_token, refresh_token)) - } - .scoped() - }) - .await - .map_err(|_| ())?; + let (access_token, refresh_token) = with_transaction!(self.db_pool, |tx| { + let access_token = diesel::insert_into(oauth2_access_tokens::table) + .values(oauth2::NewAccessToken { + token: generate_secret().as_str(), + user_id: Some(user_id), + application_id: Some(application_id), + scopes: scopes.as_str(), + expires_at, + }) + .returning(oauth2::AccessToken::as_returning()) + .get_result::(tx) + .await?; + + let refresh_token = diesel::insert_into(oauth2_refresh_tokens::table) + .values(oauth2::NewRefreshToken { + token: generate_secret().as_str(), + access_token: access_token.token.as_str(), + application_id, + }) + .returning(oauth2::RefreshToken::as_returning()) + .get_result::(tx) + .await?; + + Ok::<_, Error>((access_token, refresh_token)) + }) + .map_err(|_| ())?; Ok(IssuedToken { token: access_token.token, @@ -72,49 +65,38 @@ impl Issuer for OAuthIssuer { } async fn refresh(&mut self, refresh_token: &str, grant: Grant) -> Result { - let (refresh_token, access_token) = self - .db_pool - .with_connection(|db_conn| { - oauth2_refresh_tokens::table - .find(refresh_token) - .inner_join(oauth2_access_tokens::table) - .select(<(oauth2::RefreshToken, oauth2::AccessToken)>::as_select()) - .get_result::<(oauth2::RefreshToken, oauth2::AccessToken)>(db_conn) - .scoped() - }) - .await - .map_err(|_| ())?; - - let (access_token, refresh_token) = self - .db_pool - .with_transaction(|tx| { - async move { - let new_access_token = diesel::insert_into(oauth2_access_tokens::table) - .values(oauth2::NewAccessToken { - user_id: access_token.user_id, - token: generate_secret().as_str(), - application_id: access_token.application_id, - scopes: access_token.scopes.as_str(), - expires_at: chrono_to_timestamp(grant.until), - }) - .get_result::(tx) - .await?; - - let refresh_token = diesel::update(&refresh_token) - .set( - oauth2_refresh_tokens::access_token.eq(new_access_token.token.as_str()), - ) - .get_result::(tx) - .await?; - - diesel::delete(&access_token).execute(tx).await?; - - Ok::<_, Error>((new_access_token, refresh_token)) - } - .scoped() - }) - .await - .map_err(|_| ())?; + let (refresh_token, access_token) = with_connection!(self.db_pool, |db_conn| { + oauth2_refresh_tokens::table + .find(refresh_token) + .inner_join(oauth2_access_tokens::table) + .select(<(oauth2::RefreshToken, oauth2::AccessToken)>::as_select()) + .get_result::<(oauth2::RefreshToken, oauth2::AccessToken)>(db_conn) + .await + }) + .map_err(|_| ())?; + + let (access_token, refresh_token) = with_transaction!(self.db_pool, |tx| { + let new_access_token = diesel::insert_into(oauth2_access_tokens::table) + .values(oauth2::NewAccessToken { + user_id: access_token.user_id, + token: generate_secret().as_str(), + application_id: access_token.application_id, + scopes: access_token.scopes.as_str(), + expires_at: chrono_to_timestamp(grant.until), + }) + .get_result::(tx) + .await?; + + let refresh_token = diesel::update(&refresh_token) + .set(oauth2_refresh_tokens::access_token.eq(new_access_token.token.as_str())) + .get_result::(tx) + .await?; + + diesel::delete(&access_token).execute(tx).await?; + + Ok::<_, Error>((new_access_token, refresh_token)) + }) + .map_err(|_| ())?; Ok(RefreshedToken { token: access_token.token, @@ -125,22 +107,16 @@ impl Issuer for OAuthIssuer { } async fn recover_token(&mut self, access_token: &str) -> Result, ()> { - let oauth_data = self - .db_pool - .with_connection(|db_conn| { - async move { - oauth2_access_tokens::table - .find(access_token) - .inner_join(oauth2_applications::table) - .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) - .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) - .await - .optional() - } - .scoped() - }) - .await - .map_err(|_| ())?; + let oauth_data = with_connection!(self.db_pool, |db_conn| { + oauth2_access_tokens::table + .find(access_token) + .inner_join(oauth2_applications::table) + .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) + .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) + .await + .optional() + }) + .map_err(|_| ())?; let oauth_data = oauth_data.map(|(access_token, app)| { let scope = app.scopes.parse().unwrap(); @@ -165,23 +141,17 @@ impl Issuer for OAuthIssuer { } async fn recover_refresh(&mut self, refresh_token: &str) -> Result, ()> { - let oauth_data = self - .db_pool - .with_connection(|db_conn| { - async move { - oauth2_refresh_tokens::table - .find(refresh_token) - .inner_join(oauth2_access_tokens::table) - .inner_join(oauth2_applications::table) - .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) - .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) - .await - .optional() - } - .scoped() - }) - .await - .map_err(|_| ())?; + let oauth_data = with_connection!(self.db_pool, |db_conn| { + oauth2_refresh_tokens::table + .find(refresh_token) + .inner_join(oauth2_access_tokens::table) + .inner_join(oauth2_applications::table) + .select(<(oauth2::AccessToken, oauth2::Application)>::as_select()) + .get_result::<(oauth2::AccessToken, oauth2::Application)>(db_conn) + .await + .optional() + }) + .map_err(|_| ())?; let oauth_data = oauth_data.map(|(access_token, app)| { let scope = access_token.scopes.parse().unwrap(); diff --git a/kitsune/src/oauth2/mod.rs b/kitsune/src/oauth2/mod.rs index 5df192737..f171ecb3c 100644 --- a/kitsune/src/oauth2/mod.rs +++ b/kitsune/src/oauth2/mod.rs @@ -8,12 +8,11 @@ use iso8601_timestamp::Timestamp; use kitsune_db::{ model::oauth2, schema::{oauth2_applications, oauth2_authorization_codes}, - PgPool, + with_connection, PgPool, }; use kitsune_url::UrlService; use kitsune_util::generate_secret; use oxide_auth::endpoint::Scope; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use std::str::{self, FromStr}; use strum::{AsRefStr, EnumIter, EnumMessage, EnumString}; @@ -94,22 +93,20 @@ pub struct OAuth2Service { impl OAuth2Service { pub async fn create_app(&self, create_app: CreateApp) -> Result { let secret = generate_secret(); - self.db_pool - .with_connection(|db_conn| { - diesel::insert_into(oauth2_applications::table) - .values(oauth2::NewApplication { - id: Uuid::now_v7(), - secret: secret.as_str(), - name: create_app.name.as_str(), - redirect_uri: create_app.redirect_uris.as_str(), - scopes: "", - website: None, - }) - .get_result(db_conn) - .scoped() - }) - .await - .map_err(Error::from) + with_connection!(self.db_pool, |db_conn| { + diesel::insert_into(oauth2_applications::table) + .values(oauth2::NewApplication { + id: Uuid::now_v7(), + secret: secret.as_str(), + name: create_app.name.as_str(), + redirect_uri: create_app.redirect_uris.as_str(), + scopes: "", + website: None, + }) + .get_result(db_conn) + .await + }) + .map_err(Error::from) } pub async fn create_authorisation_code_response( @@ -124,9 +121,8 @@ impl OAuth2Service { let secret = generate_secret(); let scopes = scopes.to_string(); - let authorization_code: oauth2::AuthorizationCode = self - .db_pool - .with_connection(|db_conn| { + let authorization_code: oauth2::AuthorizationCode = + with_connection!(self.db_pool, |db_conn| { diesel::insert_into(oauth2_authorization_codes::table) .values(oauth2::NewAuthorizationCode { code: secret.as_str(), @@ -136,9 +132,8 @@ impl OAuth2Service { expires_at: Timestamp::now_utc() + AUTH_TOKEN_VALID_DURATION, }) .get_result(db_conn) - .scoped() - }) - .await?; + .await + })?; if application.redirect_uri == SHOW_TOKEN_URI { Ok(ShowTokenPage { diff --git a/kitsune/src/oauth2/registrar.rs b/kitsune/src/oauth2/registrar.rs index de35904f6..51dd26145 100644 --- a/kitsune/src/oauth2/registrar.rs +++ b/kitsune/src/oauth2/registrar.rs @@ -1,13 +1,12 @@ use async_trait::async_trait; use diesel::{ExpressionMethods, OptionalExtension, QueryDsl}; use diesel_async::RunQueryDsl; -use kitsune_db::{model::oauth2, schema::oauth2_applications, PgPool}; +use kitsune_db::{model::oauth2, schema::oauth2_applications, with_connection, PgPool}; use oxide_auth::{ endpoint::{PreGrant, Scope}, primitives::registrar::{BoundClient, ClientUrl, ExactUrl, RegisteredUrl, RegistrarError}, }; use oxide_auth_async::primitives::Registrar; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use std::{ borrow::Cow, @@ -47,22 +46,16 @@ impl Registrar for OAuthRegistrar { .parse() .map_err(|_| RegistrarError::PrimitiveError)?; - let client = self - .db_pool - .with_connection(|db_conn| { - async move { - oauth2_applications::table - .find(client_id) - .filter(oauth2_applications::redirect_uri.eq(client.redirect_uri.as_str())) - .get_result::(db_conn) - .await - .optional() - } - .scoped() - }) - .await - .map_err(|_| RegistrarError::PrimitiveError)? - .ok_or(RegistrarError::Unspecified)?; + let client = with_connection!(self.db_pool, |db_conn| { + oauth2_applications::table + .find(client_id) + .filter(oauth2_applications::redirect_uri.eq(client.redirect_uri.as_str())) + .get_result::(db_conn) + .await + .optional() + }) + .map_err(|_| RegistrarError::PrimitiveError)? + .ok_or(RegistrarError::Unspecified)?; let client_id = client.id.to_string(); let redirect_uri = ExactUrl::new(client.redirect_uri) @@ -111,20 +104,15 @@ impl Registrar for OAuthRegistrar { client_query = client_query.filter(oauth2_applications::secret.eq(passphrase)); } - self.db_pool - .with_connection(|db_conn| { - async move { - client_query - .select(oauth2_applications::id) - .execute(db_conn) - .await - .optional() - } - .scoped() - }) - .await - .map_err(|_| RegistrarError::PrimitiveError)? - .map(|_| ()) - .ok_or(RegistrarError::Unspecified) + with_connection!(self.db_pool, |db_conn| { + client_query + .select(oauth2_applications::id) + .execute(db_conn) + .await + .optional() + }) + .map_err(|_| RegistrarError::PrimitiveError)? + .map(|_| ()) + .ok_or(RegistrarError::Unspecified) } } diff --git a/kitsune/src/oauth2/solicitor.rs b/kitsune/src/oauth2/solicitor.rs index aa664e172..f9747ad50 100644 --- a/kitsune/src/oauth2/solicitor.rs +++ b/kitsune/src/oauth2/solicitor.rs @@ -4,11 +4,10 @@ use async_trait::async_trait; use cursiv::CsrfHandle; use diesel::{OptionalExtension, QueryDsl}; use diesel_async::RunQueryDsl; -use kitsune_db::{model::user::User, schema::oauth2_applications, PgPool}; +use kitsune_db::{model::user::User, schema::oauth2_applications, with_connection, PgPool}; use oxide_auth::endpoint::{OAuthError, OwnerConsent, QueryParameter, Solicitation, WebRequest}; use oxide_auth_async::endpoint::OwnerSolicitor; use oxide_auth_axum::{OAuthRequest, OAuthResponse, WebError}; -use scoped_futures::ScopedFutureExt; use speedy_uuid::Uuid; use std::{borrow::Cow, str::FromStr}; use strum::EnumMessage; @@ -80,22 +79,16 @@ impl OAuthOwnerSolicitor { .parse() .map_err(|_| WebError::Endpoint(OAuthError::BadRequest))?; - let app_name = self - .db_pool - .with_connection(|db_conn| { - async move { - oauth2_applications::table - .find(client_id) - .select(oauth2_applications::name) - .get_result::(db_conn) - .await - .optional() - } - .scoped() - }) - .await - .map_err(|_| WebError::InternalError(None))? - .ok_or(WebError::Endpoint(OAuthError::DenySilently))?; + let app_name = with_connection!(self.db_pool, |db_conn| { + oauth2_applications::table + .find(client_id) + .select(oauth2_applications::name) + .get_result::(db_conn) + .await + .optional() + }) + .map_err(|_| WebError::InternalError(None))? + .ok_or(WebError::Endpoint(OAuthError::DenySilently))?; let scopes = solicitation .pre_grant()