diff --git a/Cargo.lock b/Cargo.lock index 6386f9685..4d40c4589 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -626,10 +626,11 @@ dependencies = [ "clap", "conduit_api", "conduit_core", + "conduit_database", "conduit_macros", "conduit_service", "const-str", - "futures-util", + "futures", "log", "ruma", "serde_json", @@ -652,7 +653,7 @@ dependencies = [ "conduit_database", "conduit_service", "const-str", - "futures-util", + "futures", "hmac", "http", "http-body-util", @@ -689,6 +690,7 @@ dependencies = [ "cyborgtime", "either", "figment", + "futures", "hardened_malloc-rs", "http", "http-body-util", @@ -707,6 +709,7 @@ dependencies = [ "serde", "serde_json", "serde_regex", + "serde_yaml", "thiserror", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", @@ -724,10 +727,14 @@ dependencies = [ name = "conduit_database" version = "0.4.7" dependencies = [ + "arrayvec", "conduit_core", "const-str", + "futures", "log", "rust-rocksdb-uwu", + "serde", + "serde_json", "tokio", "tracing", ] @@ -784,7 +791,7 @@ dependencies = [ "conduit_core", "conduit_database", "const-str", - "futures-util", + "futures", "hickory-resolver", "http", "image", @@ -1283,6 +1290,20 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1345,6 +1366,7 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -2953,7 +2975,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "assign", "js_int", @@ -2975,7 +2997,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "ruma-common", @@ -2987,7 +3009,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "as_variant", "assign", @@ -3010,7 +3032,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "as_variant", "base64 0.22.1", @@ -3040,7 +3062,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3064,7 +3086,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "bytes", "http", @@ -3082,7 +3104,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "thiserror", @@ -3091,7 +3113,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "ruma-common", @@ -3101,7 +3123,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "cfg-if", "once_cell", @@ -3117,7 +3139,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "js_int", "ruma-common", @@ -3129,7 +3151,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "headers", "http", @@ -3142,7 +3164,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3158,8 +3180,9 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e#ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" dependencies = [ + "futures-util", "itertools 0.12.1", "js_int", "ruma-common", diff --git a/Cargo.toml b/Cargo.toml index b75c49757..28e280cfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -210,9 +210,10 @@ features = [ "string", ] -[workspace.dependencies.futures-util] +[workspace.dependencies.futures] version = "0.3.30" default-features = false +features = ["std"] [workspace.dependencies.tokio] version = "1.40.0" @@ -314,7 +315,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "9900d0676564883cfade556d6e8da2a2c9061efd" +rev = "ade2f1daf0b1d9e8f7de81a24dca8925406e4d8e" features = [ "compat", "rand", @@ -463,7 +464,6 @@ version = "1.0.36" [workspace.dependencies.proc-macro2] version = "1.0.89" - # # Patches # @@ -828,6 +828,7 @@ missing_panics_doc = { level = "allow", priority = 1 } module_name_repetitions = { level = "allow", priority = 1 } no_effect_underscore_binding = { level = "allow", priority = 1 } similar_names = { level = "allow", priority = 1 } +single_match_else = { level = "allow", priority = 1 } struct_field_names = { level = "allow", priority = 1 } unnecessary_wraps = { level = "allow", priority = 1 } unused_async = { level = "allow", priority = 1 } diff --git a/clippy.toml b/clippy.toml index c942b93c7..08641fcc1 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,6 +2,6 @@ array-size-threshold = 4096 cognitive-complexity-threshold = 94 # TODO reduce me ALARA excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5 future-size-threshold = 7745 # TODO reduce me ALARA -stack-size-threshold = 144000 # reduce me ALARA +stack-size-threshold = 196608 # reduce me ALARA too-many-lines-threshold = 700 # TODO reduce me to <= 100 type-complexity-threshold = 250 # reduce me to ~200 diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index d756b3cbd..f5cab4496 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -29,10 +29,11 @@ release_max_log_level = [ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true +conduit-database.workspace = true conduit-macros.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true log.workspace = true ruma.workspace = true serde_json.workspace = true diff --git a/src/admin/check/commands.rs b/src/admin/check/commands.rs index 0a9830464..88fca462f 100644 --- a/src/admin/check/commands.rs +++ b/src/admin/check/commands.rs @@ -1,5 +1,6 @@ use conduit::Result; use conduit_macros::implement; +use futures::StreamExt; use ruma::events::room::message::RoomMessageEventContent; use crate::Command; @@ -10,14 +11,12 @@ use crate::Command; #[implement(Command, params = "<'_>")] pub(super) async fn check_all_users(&self) -> Result { let timer = tokio::time::Instant::now(); - let results = self.services.users.db.iter(); + let users = self.services.users.iter().collect::>().await; let query_time = timer.elapsed(); - let users = results.collect::>(); - let total = users.len(); - let err_count = users.iter().filter(|user| user.is_err()).count(); - let ok_count = users.iter().filter(|user| user.is_ok()).count(); + let err_count = users.iter().filter(|_user| false).count(); + let ok_count = users.iter().filter(|_user| true).count(); let message = format!( "Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \ diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 2d9670064..350e08c6a 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -7,6 +7,7 @@ use std::{ use api::client::validate_and_add_event_id; use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, @@ -26,32 +27,32 @@ pub(super) async fn echo(&self, message: Vec) -> Result) -> Result { - let event_id = Arc::::from(event_id); - if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? { - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - - let start = Instant::now(); - let count = self - .services - .rooms - .auth_chain - .event_ids_iter(room_id, vec![event_id]) - .await? - .count(); - - let elapsed = start.elapsed(); - Ok(RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {count} in {elapsed:?}" - ))) - } else { - Ok(RoomMessageEventContent::text_plain("Event not found.")) - } + let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else { + return Ok(RoomMessageEventContent::notice_plain("Event not found.")); + }; + + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + + let start = Instant::now(); + let count = self + .services + .rooms + .auth_chain + .event_ids_iter(room_id, &[&event_id]) + .await? + .count() + .await; + + let elapsed = start.elapsed(); + Ok(RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {count} in {elapsed:?}" + ))) } #[admin_command] @@ -91,13 +92,16 @@ pub(super) async fn get_pdu(&self, event_id: Box) -> Result { + Ok(json) => { let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); Ok(RoomMessageEventContent::notice_markdown(format!( "{}\n```json\n{}\n```", @@ -109,7 +113,7 @@ pub(super) async fn get_pdu(&self, event_id: Box) -> Result Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), + Err(_) => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), } } @@ -157,7 +161,8 @@ pub(super) async fn get_remote_pdu_list( .send_message(RoomMessageEventContent::text_plain(format!( "Failed to get remote PDU, ignoring error: {e}" ))) - .await; + .await + .ok(); warn!("Failed to get remote PDU, ignoring error: {e}"); } else { success_count = success_count.saturating_add(1); @@ -215,7 +220,9 @@ pub(super) async fn get_remote_pdu( .services .rooms .event_handler - .parse_incoming_pdu(&response.pdu); + .parse_incoming_pdu(&response.pdu) + .await; + let (event_id, value, room_id) = match parsed_result { Ok(t) => t, Err(e) => { @@ -333,9 +340,12 @@ pub(super) async fn ping(&self, server: Box) -> Result Result { // Force E2EE device list updates for all users - for user_id in self.services.users.iter().filter_map(Result::ok) { - self.services.users.mark_device_key_update(&user_id)?; - } + self.services + .users + .stream() + .for_each(|user_id| self.services.users.mark_device_key_update(user_id)) + .await; + Ok(RoomMessageEventContent::text_plain( "Marked all devices for all users as having new keys to update", )) @@ -470,7 +480,8 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box) -> Result) -> Result) -> Result) -> Result> = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); @@ -554,13 +570,21 @@ pub(super) async fn force_set_room_state_from_server( let mut events = Vec::with_capacity(remote_state_response.pdus.len()); for pdu in remote_state_response.pdus.clone() { - events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) { - Ok(t) => t, - Err(e) => { - warn!("Could not parse PDU, ignoring: {e}"); - continue; + events.push( + match self + .services + .rooms + .event_handler + .parse_incoming_pdu(&pdu) + .await + { + Ok(t) => t, + Err(e) => { + warn!("Could not parse PDU, ignoring: {e}"); + continue; + }, }, - }); + ); } info!("Fetching required signing keys for all the state events we got"); @@ -587,13 +611,16 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); + if let Some(state_key) = &pdu.state_key { let shortstatekey = self .services .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; + state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -611,7 +638,7 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); } let new_room_state = self @@ -626,7 +653,8 @@ pub(super) async fn force_set_room_state_from_server( .services .rooms .state_compressor - .save_state(room_id.clone().as_ref(), new_room_state)?; + .save_state(room_id.clone().as_ref(), new_room_state) + .await?; let state_lock = self.services.rooms.state.mutex.lock(&room_id).await; self.services @@ -642,7 +670,8 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .state_cache - .update_joined_count(&room_id)?; + .update_joined_count(&room_id) + .await; drop(state_lock); @@ -656,7 +685,7 @@ pub(super) async fn get_signing_keys( &self, server_name: Option>, _cached: bool, ) -> Result { let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); - let signing_keys = self.services.globals.signing_keys_for(&server_name)?; + let signing_keys = self.services.globals.signing_keys_for(&server_name).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "```rs\n{signing_keys:#?}\n```" @@ -674,7 +703,7 @@ pub(super) async fn get_verify_keys( if cached { writeln!(out, "| Key ID | VerifyKey |")?; writeln!(out, "| --- | --- |")?; - for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? { + for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name).await? { writeln!(out, "| {key_id} | {verify_key:?} |")?; } diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 8917a46b9..ce95ac01b 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -1,19 +1,20 @@ use std::fmt::Write; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; use crate::{admin_command, escape_html, get_room_info}; #[admin_command] pub(super) async fn disable_room(&self, room_id: Box) -> Result { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); Ok(RoomMessageEventContent::text_plain("Room disabled.")) } #[admin_command] pub(super) async fn enable_room(&self, room_id: Box) -> Result { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); Ok(RoomMessageEventContent::text_plain("Room enabled.")) } @@ -85,7 +86,7 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< )); } - if !self.services.users.exists(&user_id)? { + if !self.services.users.exists(&user_id).await { return Ok(RoomMessageEventContent::text_plain( "Remote user does not exist in our database.", )); @@ -96,9 +97,9 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box) -> Result< .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .map(|room_id| get_room_info(self.services, &room_id)) - .collect(); + .then(|room_id| get_room_info(self.services, room_id)) + .collect() + .await; if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); diff --git a/src/admin/media/commands.rs b/src/admin/media/commands.rs index 3c4bf2ef8..82ac162eb 100644 --- a/src/admin/media/commands.rs +++ b/src/admin/media/commands.rs @@ -36,7 +36,7 @@ pub(super) async fn delete( let mut mxc_urls = Vec::with_capacity(4); // parsing the PDU for any MXC URLs begins here - if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? { + if let Ok(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id).await { if let Some(content_key) = event_json.get("content") { debug!("Event ID has \"content\"."); let content_obj = content_key.as_object(); @@ -300,7 +300,7 @@ pub(super) async fn delete_all_from_server( #[admin_command] pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result { let mxc: Mxc<'_> = mxc.as_str().try_into()?; - let metadata = self.services.media.get_metadata(&mxc); + let metadata = self.services.media.get_metadata(&mxc).await; Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```"))) } diff --git a/src/admin/processor.rs b/src/admin/processor.rs index 4f60f56e9..3c1895ffd 100644 --- a/src/admin/processor.rs +++ b/src/admin/processor.rs @@ -17,7 +17,7 @@ use conduit::{ utils::string::{collect_stream, common_prefix}, warn, Error, Result, }; -use futures_util::future::FutureExt; +use futures::future::FutureExt; use ruma::{ events::{ relation::InReplyTo, diff --git a/src/admin/query/account_data.rs b/src/admin/query/account_data.rs index e18c298a3..896bf95cf 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -44,7 +44,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .changes_since(room_id.as_deref(), &user_id, since)?; + .changes_since(room_id.as_deref(), &user_id, since) + .await?; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -59,7 +60,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .get(room_id.as_deref(), &user_id, kind)?; + .get(room_id.as_deref(), &user_id, kind) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 683c228f7..4b97ef4eb 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -29,7 +29,9 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> let results = services .appservice .db - .get_registration(appservice_id.as_ref()); + .get_registration(appservice_id.as_ref()) + .await; + let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -38,7 +40,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> }, AppserviceCommand::All => { let timer = tokio::time::Instant::now(); - let results = services.appservice.all(); + let results = services.appservice.all().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 5f271c2c4..150a213cd 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -29,7 +29,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - match subcommand { GlobalsCommand::DatabaseVersion => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.database_version(); + let results = services.globals.db.database_version().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -47,7 +47,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - }, GlobalsCommand::LastCheckForUpdatesId => { let timer = tokio::time::Instant::now(); - let results = services.updates.last_check_for_updates_id(); + let results = services.updates.last_check_for_updates_id().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -67,7 +67,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - origin, } => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.verify_keys_for(&origin); + let results = services.globals.db.verify_keys_for(&origin).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/presence.rs b/src/admin/query/presence.rs index 145ecd9b1..6189270cc 100644 --- a/src/admin/query/presence.rs +++ b/src/admin/query/presence.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, UserId}; use crate::Command; @@ -30,7 +31,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.presence.db.get_presence(&user_id)?; + let results = services.presence.db.get_presence(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -42,7 +43,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) } => { let timer = tokio::time::Instant::now(); let results = services.presence.db.presence_since(since); - let presence_since: Vec<(_, _, _)> = results.collect(); + let presence_since: Vec<(_, _, _)> = results.collect().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/pusher.rs b/src/admin/query/pusher.rs index 637c57b65..a1bd32f99 100644 --- a/src/admin/query/pusher.rs +++ b/src/admin/query/pusher.rs @@ -21,7 +21,7 @@ pub(super) async fn process(subcommand: PusherCommand, context: &Command<'_>) -> user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.pusher.get_pushers(&user_id)?; + let results = services.pusher.get_pushers(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_alias.rs b/src/admin/query/room_alias.rs index 1809e26a0..382e4a784 100644 --- a/src/admin/query/room_alias.rs +++ b/src/admin/query/room_alias.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; use crate::Command; @@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) alias, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.resolve_local_alias(&alias); + let results = services.rooms.alias.resolve_local_alias(&alias).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -42,8 +43,13 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.local_aliases_for_room(&room_id); - let aliases: Vec<_> = results.collect(); + let aliases: Vec<_> = services + .rooms + .alias + .local_aliases_for_room(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -52,8 +58,13 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) }, RoomAliasCommand::AllLocalAliases => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.all_local_aliases(); - let aliases: Vec<_> = results.collect(); + let aliases = services + .rooms + .alias + .all_local_aliases() + .map(|(room_id, alias)| (room_id.to_owned(), alias.to_owned())) + .collect::>() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_state_cache.rs b/src/admin/query/room_state_cache.rs index 4215cf8d6..e32517fb1 100644 --- a/src/admin/query/room_state_cache.rs +++ b/src/admin/query/room_state_cache.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; use crate::Command; @@ -86,7 +87,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let result = services.rooms.state_cache.server_in_room(&server, &room_id); + let result = services + .rooms + .state_cache + .server_in_room(&server, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -97,7 +102,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.room_servers(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_servers(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -108,7 +119,13 @@ pub(super) async fn process( server, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.server_rooms(&server).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .server_rooms(&server) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -119,7 +136,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.room_members(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_members(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -134,7 +157,9 @@ pub(super) async fn process( .rooms .state_cache .local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -149,7 +174,9 @@ pub(super) async fn process( .rooms .state_cache .active_local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -160,7 +187,7 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_joined_count(&room_id); + let results = services.rooms.state_cache.room_joined_count(&room_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -171,7 +198,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_invited_count(&room_id); + let results = services + .rooms + .state_cache + .room_invited_count(&room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -182,11 +213,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services + let results: Vec<_> = services .rooms .state_cache .room_useroncejoined(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -197,11 +230,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services + let results: Vec<_> = services .rooms .state_cache .room_members_invited(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -216,7 +251,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_invite_count(&room_id, &user_id); + .get_invite_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -231,7 +267,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_left_count(&room_id, &user_id); + .get_left_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -242,7 +279,13 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_joined(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_joined(&user_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -253,7 +296,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_invited(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_invited(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -264,7 +312,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result> = services.rooms.state_cache.rooms_left(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_left(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -276,7 +329,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.invite_state(&user_id, &room_id); + let results = services + .rooms + .state_cache + .invite_state(&user_id, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/sending.rs b/src/admin/query/sending.rs index 6d54bddfd..eaab1f5ee 100644 --- a/src/admin/query/sending.rs +++ b/src/admin/query/sending.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId}; use service::sending::Destination; @@ -68,7 +69,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - SendingCommand::ActiveRequests => { let timer = tokio::time::Instant::now(); let results = services.sending.db.active_requests(); - let active_requests: Result> = results.collect(); + let active_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -133,7 +134,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let queued_requests = results.collect::>>(); + let queued_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -199,7 +200,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let active_requests = results.collect::>>(); + let active_requests = results.collect::>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -210,7 +211,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - server_name, } => { let timer = tokio::time::Instant::now(); - let results = services.sending.db.get_latest_educount(&server_name); + let results = services.sending.db.get_latest_educount(&server_name).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/users.rs b/src/admin/query/users.rs index fee12fbfc..0792e4840 100644 --- a/src/admin/query/users.rs +++ b/src/admin/query/users.rs @@ -1,29 +1,344 @@ use clap::Subcommand; use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::stream::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedDeviceId, OwnedRoomId, OwnedUserId}; -use crate::Command; +use crate::{admin_command, admin_command_dispatch}; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/users.rs pub(crate) enum UsersCommand { - Iter, + CountUsers, + + IterUsers, + + PasswordHash { + user_id: OwnedUserId, + }, + + ListDevices { + user_id: OwnedUserId, + }, + + ListDevicesMetadata { + user_id: OwnedUserId, + }, + + GetDeviceMetadata { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDevicesVersion { + user_id: OwnedUserId, + }, + + CountOneTimeKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDeviceKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetUserSigningKey { + user_id: OwnedUserId, + }, + + GetMasterKey { + user_id: OwnedUserId, + }, + + GetToDeviceEvents { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetLatestBackup { + user_id: OwnedUserId, + }, + + GetLatestBackupVersion { + user_id: OwnedUserId, + }, + + GetBackupAlgorithm { + user_id: OwnedUserId, + version: String, + }, + + GetAllBackups { + user_id: OwnedUserId, + version: String, + }, + + GetRoomBackups { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + }, + + GetBackupSession { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + session_id: String, + }, +} + +#[admin_command] +async fn get_backup_session( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, session_id: String, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_session(&user_id, &version, &room_id, &session_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_room_backups( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_room(&user_id, &version, &room_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_all_backups(&self, user_id: OwnedUserId, version: String) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_all(&user_id, &version).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_backup_algorithm(&self, user_id: OwnedUserId, version: String) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_backup(&user_id, &version) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup_version(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_latest_backup_version(&user_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_latest_backup(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } -/// All the getters and iterators in key_value/users.rs -pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result { - let services = context.services; +#[admin_command] +async fn iter_users(&self) -> Result { + let timer = tokio::time::Instant::now(); + let result: Vec = self.services.users.stream().map(Into::into).collect().await; + + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn count_users(&self) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.count().await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn password_hash(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.password_hash(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_device_ids(&user_id) + .map(ToOwned::to_owned) + .collect::>() + .await; + + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices_metadata(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_devices_metadata(&user_id) + .collect::>() + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_metadata(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let device = self + .services + .users + .get_device_metadata(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn get_devices_version(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let device = self.services.users.get_devicelist_version(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn count_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .count_one_time_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_device_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_user_signing_key(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self.services.users.get_user_signing_key(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_master_key(&self, user_id: OwnedUserId) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_master_key(None, &user_id, &|_| true) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} - match subcommand { - UsersCommand::Iter => { - let timer = tokio::time::Instant::now(); - let results = services.users.db.iter(); - let users = results.collect::>(); - let query_time = timer.elapsed(); +#[admin_command] +async fn get_to_device_events( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, +) -> Result { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_to_device_events(&user_id, &device_id) + .collect::>() + .await; + let query_time = timer.elapsed(); - Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{users:#?}\n```" - ))) - }, - } + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } diff --git a/src/admin/room/alias.rs b/src/admin/room/alias.rs index 415b8a083..1ccde47dc 100644 --- a/src/admin/room/alias.rs +++ b/src/admin/room/alias.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::{escape_html, Command}; @@ -66,8 +67,8 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> force, room_id, .. - } => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) { - (true, Ok(Some(id))) => match services + } => match (force, services.rooms.alias.resolve_local_alias(&room_alias).await) { + (true, Ok(id)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -77,10 +78,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> ))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!( + (false, Ok(id)) => Ok(RoomMessageEventContent::text_plain(format!( "Refusing to overwrite in use alias for {id}, use -f or --force to overwrite" ))), - (_, Ok(None)) => match services + (_, Err(_)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -88,12 +89,11 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), }, RoomAliasCommand::Remove { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => match services + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => match services .rooms .alias .remove_alias(&room_alias, server_user) @@ -102,15 +102,13 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::Which { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::List { .. @@ -121,67 +119,63 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> room_id, } => { if let Some(room_id) = room_id { - let aliases = services + let aliases: Vec = services .rooms .alias .local_aliases_for_room(&room_id) - .collect::, _>>(); - match aliases { - Ok(aliases) => { - let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "- {alias}").expect("should be able to write to string buffer"); - output - }); - - let html_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "
  • {}
  • ", escape_html(alias.as_ref())) - .expect("should be able to write to string buffer"); - output - }); - - let plain = format!("Aliases for {room_id}:\n{plain_list}"); - let html = format!("Aliases for {room_id}:\n
      {html_list}
    "); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))), - } + .map(Into::into) + .collect() + .await; + + let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "- {alias}").expect("should be able to write to string buffer"); + output + }); + + let html_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "
  • {}
  • ", escape_html(alias.as_ref())) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases for {room_id}:\n{plain_list}"); + let html = format!("Aliases for {room_id}:\n
      {html_list}
    "); + Ok(RoomMessageEventContent::text_html(plain, html)) } else { let aliases = services .rooms .alias .all_local_aliases() - .collect::, _>>(); - match aliases { - Ok(aliases) => { - let server_name = services.globals.server_name(); - let plain_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!(output, "- `{alias}` -> #{id}:{server_name}") - .expect("should be able to write to string buffer"); - output - }); - - let html_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!( - output, - "
  • {} -> #{}:{}
  • ", - escape_html(alias.as_ref()), - escape_html(id.as_ref()), - server_name - ) - .expect("should be able to write to string buffer"); - output - }); - - let plain = format!("Aliases:\n{plain_list}"); - let html = format!("Aliases:\n
      {html_list}
    "); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))), - } + .map(|(room_id, localpart)| (room_id.into(), localpart.into())) + .collect::>() + .await; + + let server_name = services.globals.server_name(); + let plain_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!(output, "- `{alias}` -> #{id}:{server_name}") + .expect("should be able to write to string buffer"); + output + }); + + let html_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!( + output, + "
  • {} -> #{}:{}
  • ", + escape_html(alias.as_ref()), + escape_html(id), + server_name + ) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases:\n{plain_list}"); + let html = format!("Aliases:\n
      {html_list}
    "); + Ok(RoomMessageEventContent::text_html(plain, html)) } }, } diff --git a/src/admin/room/commands.rs b/src/admin/room/commands.rs index 2adfa7d73..1c90a9983 100644 --- a/src/admin/room/commands.rs +++ b/src/admin/room/commands.rs @@ -1,11 +1,12 @@ use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; use crate::{admin_command, get_room_info, PAGE_SIZE}; #[admin_command] pub(super) async fn list_rooms( - &self, page: Option, exclude_disabled: bool, exclude_banned: bool, no_details: bool, + &self, page: Option, _exclude_disabled: bool, _exclude_banned: bool, no_details: bool, ) -> Result { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); @@ -14,37 +15,12 @@ pub(super) async fn list_rooms( .rooms .metadata .iter_ids() - .filter_map(|room_id| { - room_id - .ok() - .filter(|room_id| { - if exclude_disabled - && self - .services - .rooms - .metadata - .is_disabled(room_id) - .unwrap_or(false) - { - return false; - } + //.filter(|room_id| async { !exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await }) + //.filter(|room_id| async { !exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await }) + .then(|room_id| get_room_info(self.services, room_id)) + .collect::>() + .await; - if exclude_banned - && self - .services - .rooms - .metadata - .is_banned(room_id) - .unwrap_or(false) - { - return false; - } - - true - }) - .map(|room_id| get_room_info(self.services, &room_id)) - }) - .collect::>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); @@ -74,3 +50,10 @@ pub(super) async fn list_rooms( Ok(RoomMessageEventContent::notice_markdown(output_plain)) } + +#[admin_command] +pub(super) async fn exists(&self, room_id: OwnedRoomId) -> Result { + let result = self.services.rooms.metadata.exists(&room_id).await; + + Ok(RoomMessageEventContent::notice_markdown(format!("{result}"))) +} diff --git a/src/admin/room/directory.rs b/src/admin/room/directory.rs index 7bba2eb7b..1080356a8 100644 --- a/src/admin/room/directory.rs +++ b/src/admin/room/directory.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; @@ -31,36 +32,37 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> match command { RoomDirectoryCommand::Publish { room_id, - } => match services.rooms.directory.set_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room published")) }, RoomDirectoryCommand::Unpublish { room_id, - } => match services.rooms.directory.set_not_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_not_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room unpublished")) }, RoomDirectoryCommand::List { page, } => { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); - let mut rooms = services + let mut rooms: Vec<_> = services .rooms .directory .public_rooms() - .filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(services, &id)) - .collect::>(); + .then(|room_id| get_room_info(services, room_id)) + .collect() + .await; + rooms.sort_by_key(|r| r.1); rooms.reverse(); - let rooms = rooms + let rooms: Vec<_> = rooms .into_iter() .skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE)) .take(PAGE_SIZE) - .collect::>(); + .collect(); if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("No more rooms.")); diff --git a/src/admin/room/info.rs b/src/admin/room/info.rs index d17a29247..13a74a9d3 100644 --- a/src/admin/room/info.rs +++ b/src/admin/room/info.rs @@ -1,5 +1,6 @@ use clap::Subcommand; -use conduit::Result; +use conduit::{utils::ReadyExt, Result}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{admin_command, admin_command_dispatch}; @@ -32,46 +33,40 @@ async fn list_joined_members(&self, room_id: Box, local_only: bool) -> R .rooms .state_accessor .get_name(&room_id) - .ok() - .flatten() - .unwrap_or_else(|| room_id.to_string()); + .await + .unwrap_or_else(|_| room_id.to_string()); - let members = self + let member_info: Vec<_> = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|member| { - if local_only { - member - .ok() - .filter(|user| self.services.globals.user_is_local(user)) - } else { - member.ok() - } - }); - - let member_info = members - .into_iter() - .map(|user_id| { - ( - user_id.clone(), + .ready_filter(|user_id| { + local_only + .then(|| self.services.globals.user_is_local(user_id)) + .unwrap_or(true) + }) + .map(ToOwned::to_owned) + .filter_map(|user_id| async move { + Some(( self.services .users .displayname(&user_id) - .unwrap_or(None) - .unwrap_or_else(|| user_id.to_string()), - ) + .await + .unwrap_or_else(|_| user_id.to_string()), + user_id, + )) }) - .collect::>(); + .collect() + .await; let output_plain = format!( "{} Members in Room \"{}\":\n```\n{}\n```", member_info.len(), room_name, member_info - .iter() - .map(|(mxid, displayname)| format!("{mxid} | {displayname}")) + .into_iter() + .map(|(displayname, mxid)| format!("{mxid} | {displayname}")) .collect::>() .join("\n") ); @@ -81,11 +76,12 @@ async fn list_joined_members(&self, room_id: Box, local_only: bool) -> R #[admin_command] async fn view_room_topic(&self, room_id: Box) -> Result { - let Some(room_topic) = self + let Ok(room_topic) = self .services .rooms .state_accessor - .get_room_topic(&room_id)? + .get_room_topic(&room_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set.")); }; diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index 64d2af452..8c6cbeaae 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -6,6 +6,7 @@ mod moderation; use clap::Subcommand; use conduit::Result; +use ruma::OwnedRoomId; use self::{ alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand, @@ -49,4 +50,9 @@ pub(super) enum RoomCommand { #[command(subcommand)] /// - Manage the room directory Directory(RoomDirectoryCommand), + + /// - Check if we know about a room + Exists { + room_id: OwnedRoomId, + }, } diff --git a/src/admin/room/moderation.rs b/src/admin/room/moderation.rs index 70d8486b4..cfc048bdd 100644 --- a/src/admin/room/moderation.rs +++ b/src/admin/room/moderation.rs @@ -1,6 +1,11 @@ use api::client::leave_room; use clap::Subcommand; -use conduit::{debug, error, info, warn, Result}; +use conduit::{ + debug, error, info, + utils::{IterStream, ReadyExt}, + warn, Result, +}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use crate::{admin_command, admin_command_dispatch, get_room_info}; @@ -76,7 +81,7 @@ async fn ban_room( let admin_room_alias = &self.services.globals.admin_alias; - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); } @@ -95,7 +100,7 @@ async fn ban_room( debug!("Room specified is a room ID, banning room ID"); - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else if room.is_room_alias_id() { @@ -114,7 +119,13 @@ async fn ban_room( get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -138,7 +149,7 @@ async fn ban_room( } }; - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else { @@ -150,56 +161,40 @@ async fn ban_room( debug!("Making all users leave the room {}", &room); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - self.services.globals.user_is_local(local_user) - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would - // fail auth check) - && (self.services.globals.user_is_local(local_user) - // since this is a force operation, assume user is an admin - // if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, &room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ + admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e @@ -214,12 +209,14 @@ async fn ban_room( } // remove any local aliases, ignore errors - for ref local_alias in self + for local_alias in &self .services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) + .map(ToOwned::to_owned) + .collect::>() + .await { _ = self .services @@ -230,10 +227,10 @@ async fn ban_room( } // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); return Ok(RoomMessageEventContent::text_plain( "Room banned, removed all our local users, and disabled incoming federation with room.", )); @@ -268,7 +265,7 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu for &room in &rooms_s { match <&RoomOrAliasId>::try_from(room) { Ok(room_alias_or_id) => { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { info!("User specified admin room in bulk ban list, ignoring"); continue; @@ -300,43 +297,48 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu if room_alias_or_id.is_room_alias_id() { match RoomAliasId::parse(room_alias_or_id) { Ok(room_alias) => { - let room_id = - if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to fetch room \ - ID over federation" - ); - - match self - .services - .rooms - .alias - .resolve_alias(&room_alias, None) - .await - { - Ok((room_id, servers)) => { - debug!( - ?room_id, - ?servers, - "Got federation response fetching room ID for {room}", - ); - room_id - }, - Err(e) => { - // don't fail if force blocking - if force { - warn!("Failed to resolve room alias {room} to a room ID: {e}"); - continue; - } - - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, - } - }; + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { + room_id + } else { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch room ID \ + over federation" + ); + + match self + .services + .rooms + .alias + .resolve_alias(&room_alias, None) + .await + { + Ok((room_id, servers)) => { + debug!( + ?room_id, + ?servers, + "Got federation response fetching room ID for {room}", + ); + room_id + }, + Err(e) => { + // don't fail if force blocking + if force { + warn!("Failed to resolve room alias {room} to a room ID: {e}"); + continue; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }; room_ids.push(room_id); }, @@ -374,74 +376,52 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } for room_id in room_ids { - if self - .services - .rooms - .metadata - .ban_room(&room_id, true) - .is_ok() - { - debug!("Banned {room_id} successfully"); - room_ban_count = room_ban_count.saturating_add(1); - } + self.services.rooms.metadata.ban_room(&room_id, true); + + debug!("Banned {room_id} successfully"); + room_ban_count = room_ban_count.saturating_add(1); debug!("Making all users leave the room {}", &room_id); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - // since this is a force operation, assume user is an - // admin if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ + admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { - debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + + debug!("Attempting leave for user {local_user} in room {room_id}"); + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( - "Error attempting to make local user {} leave room {} during bulk room banning: {}", - &local_user, &room_id, e + "Error attempting to make local user {local_user} leave room {room_id} during bulk room \ + banning: {e}", ); + return Ok(RoomMessageEventContent::text_plain(format!( "Error attempting to make local user {} leave room {} during room banning (room is still \ banned but not removing any more users and not banning any more rooms): {}\nIf you would \ @@ -453,26 +433,26 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } // remove any local aliases, ignore errors - for ref local_alias in self - .services + self.services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) - { - _ = self - .services - .rooms - .alias - .remove_alias(local_alias, &self.services.globals.server_user) - .await; - } + .map(ToOwned::to_owned) + .for_each(|local_alias| async move { + self.services + .rooms + .alias + .remove_alias(&local_alias, &self.services.globals.server_user) + .await + .ok(); + }) + .await; // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); } } @@ -503,7 +483,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> debug!("Room specified is a room ID, unbanning room ID"); - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else if room.is_room_alias_id() { @@ -522,7 +502,13 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -546,7 +532,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> } }; - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else { @@ -557,7 +543,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> }; if enable_federation { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); } @@ -569,45 +555,42 @@ async fn unban_room(&self, enable_federation: bool, room: Box) -> #[admin_command] async fn list_banned_rooms(&self, no_details: bool) -> Result { - let rooms = self + let room_ids: Vec = self .services .rooms .metadata .list_banned_rooms() - .collect::, _>>(); + .map(Into::into) + .collect() + .await; - match rooms { - Ok(room_ids) => { - if room_ids.is_empty() { - return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); - } - - let mut rooms = room_ids - .into_iter() - .map(|room_id| get_room_info(self.services, &room_id)) - .collect::>(); - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let output_plain = format!( - "Rooms Banned ({}):\n```\n{}\n```", - rooms.len(), - rooms - .iter() - .map(|(id, members, name)| if no_details { - format!("{id}") - } else { - format!("{id}\tMembers: {members}\tName: {name}") - }) - .collect::>() - .join("\n") - ); - - Ok(RoomMessageEventContent::notice_markdown(output_plain)) - }, - Err(e) => { - error!("Failed to list banned rooms: {e}"); - Ok(RoomMessageEventContent::text_plain(format!("Unable to list banned rooms: {e}"))) - }, + if room_ids.is_empty() { + return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); } + + let mut rooms = room_ids + .iter() + .stream() + .then(|room_id| get_room_info(self.services, room_id)) + .collect::>() + .await; + + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let output_plain = format!( + "Rooms Banned ({}):\n```\n{}\n```", + rooms.len(), + rooms + .iter() + .map(|(id, members, name)| if no_details { + format!("{id}") + } else { + format!("{id}\tMembers: {members}\tName: {name}") + }) + .collect::>() + .join("\n") + ); + + Ok(RoomMessageEventContent::notice_markdown(output_plain)) } diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 20691f1a2..1b086856a 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,7 +1,9 @@ use std::{collections::BTreeMap, fmt::Write as _}; use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; -use conduit::{error, info, utils, warn, PduBuilder, Result}; +use conduit::{error, info, is_equal_to, utils, warn, PduBuilder, Result}; +use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname}; +use futures::StreamExt; use ruma::{ events::{ room::{ @@ -25,16 +27,19 @@ const AUTO_GEN_PASSWORD_LENGTH: usize = 25; #[admin_command] pub(super) async fn list_users(&self) -> Result { - match self.services.users.list_local_users() { - Ok(users) => { - let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += users.join("\n").as_str(); - plain_msg += "\n```"; + let users = self + .services + .users + .list_local_users() + .map(ToString::to_string) + .collect::>() + .await; - Ok(RoomMessageEventContent::notice_markdown(plain_msg)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())), - } + let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); + plain_msg += users.join("\n").as_str(); + plain_msg += "\n```"; + + Ok(RoomMessageEventContent::notice_markdown(plain_msg)) } #[admin_command] @@ -42,7 +47,7 @@ pub(super) async fn create_user(&self, username: String, password: Option )); } - self.services.users.deactivate_account(&user_id)?; + self.services.users.deactivate_account(&user_id).await?; if !no_leave_rooms { self.services @@ -175,17 +184,22 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) -> .send_message(RoomMessageEventContent::text_plain(format!( "Making {user_id} leave all rooms after deactivation..." ))) - .await; + .await + .ok(); let all_joined_rooms: Vec = self .services .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms).await?; + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms).await?; + leave_all_rooms(self.services, &user_id).await; } Ok(RoomMessageEventContent::text_plain(format!( @@ -238,15 +252,16 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut admins = Vec::new(); for username in usernames { - match parse_active_local_user_id(self.services, username) { + match parse_active_local_user_id(self.services, username).await { Ok(user_id) => { - if self.services.users.is_admin(&user_id)? && !force { + if self.services.users.is_admin(&user_id).await && !force { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is an admin and --force is not set, skipping over" ))) - .await; + .await + .ok(); admins.push(username); continue; } @@ -258,7 +273,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is the server service account, skipping over" ))) - .await; + .await + .ok(); continue; } @@ -270,7 +286,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is not a valid username, skipping over: {e}" ))) - .await; + .await + .ok(); continue; }, } @@ -279,7 +296,7 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut deactivation_count: usize = 0; for user_id in user_ids { - match self.services.users.deactivate_account(&user_id) { + match self.services.users.deactivate_account(&user_id).await { Ok(()) => { deactivation_count = deactivation_count.saturating_add(1); if !no_leave_rooms { @@ -289,16 +306,26 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + .map(Into::into) + .collect() + .await; + + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms) + .await + .ok(); + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms) + .await + .ok(); + leave_all_rooms(self.services, &user_id).await; } }, Err(e) => { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}"))) - .await; + .await + .ok(); }, } } @@ -326,9 +353,9 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result(&room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -417,9 +443,9 @@ pub(super) async fn force_demote( .services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if !user_can_demote_self { return Ok(RoomMessageEventContent::notice_markdown( @@ -473,15 +499,16 @@ pub(super) async fn make_user_admin(&self, user_id: String) -> Result, tag: String, ) -> Result { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -494,12 +521,15 @@ pub(super) async fn put_room_tag( .tags .insert(tag.clone().into(), TagInfo::new()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id} with tag {tag}" @@ -510,15 +540,16 @@ pub(super) async fn put_room_tag( pub(super) async fn delete_room_tag( &self, user_id: String, room_id: Box, tag: String, ) -> Result { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -528,12 +559,15 @@ pub(super) async fn delete_room_tag( tags_event.content.tags.remove(&tag.clone().into()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}" @@ -542,15 +576,16 @@ pub(super) async fn delete_room_tag( #[admin_command] pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box) -> Result { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; let event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; + .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag) + .await; let tags_event = event.map_or_else( - || TagEvent { + |_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -566,11 +601,12 @@ pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box) #[admin_command] pub(super) async fn redact_event(&self, event_id: Box) -> Result { - let Some(event) = self + let Ok(event) = self .services .rooms .timeline - .get_non_outlier_pdu(&event_id)? + .get_non_outlier_pdu(&event_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database.")); }; diff --git a/src/admin/utils.rs b/src/admin/utils.rs index 8d3d15ae4..ba98bbeac 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -8,23 +8,21 @@ pub(crate) fn escape_html(s: &str) -> String { .replace('>', ">") } -pub(crate) fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) { +pub(crate) async fn get_room_info(services: &Services, room_id: &RoomId) -> (OwnedRoomId, u64, String) { ( - id.into(), + room_id.into(), services .rooms .state_cache - .room_joined_count(id) - .ok() - .flatten() + .room_joined_count(room_id) + .await .unwrap_or(0), services .rooms .state_accessor - .get_name(id) - .ok() - .flatten() - .unwrap_or_else(|| id.to_string()), + .get_name(room_id) + .await + .unwrap_or_else(|_| room_id.to_string()), ) } @@ -46,14 +44,14 @@ pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result< } /// Parses user ID that is an active (not guest or deactivated) local user -pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result { +pub(crate) async fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result { let user_id = parse_local_user_id(services, user_id)?; - if !services.users.exists(&user_id)? { + if !services.users.exists(&user_id).await { return Err!("User {user_id:?} does not exist on this server."); } - if services.users.is_deactivated(&user_id)? { + if services.users.is_deactivated(&user_id).await? { return Err!("User {user_id:?} is deactivated."); } diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 2b89c3e82..6e37cb407 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -45,7 +45,7 @@ conduit-core.workspace = true conduit-database.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hmac.workspace = true http.workspace = true http-body-util.workspace = true diff --git a/src/api/client/account.rs b/src/api/client/account.rs index cee86f80a..63d02f8f8 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -2,7 +2,8 @@ use std::fmt::Write; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug_info, error, info, utils, warn, Error, PduBuilder, Result}; +use conduit::{debug_info, error, info, is_equal_to, utils, utils::ReadyExt, warn, Error, PduBuilder, Result}; +use futures::{FutureExt, StreamExt}; use register::RegistrationKind; use ruma::{ api::client::{ @@ -55,7 +56,7 @@ pub(crate) async fn get_register_available_route( .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough - if services.users.exists(&user_id)? { + if services.users.exists(&user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -125,7 +126,7 @@ pub(crate) async fn register_route( // forbid guests from registering if there is not a real admin user yet. give // generic user error. - if is_guest && services.users.count()? < 2 { + if is_guest && services.users.count().await < 2 { warn!( "Guest account attempted to register before a real admin user has been registered, rejecting \ registration. Guest's initial device name: {:?}", @@ -142,7 +143,7 @@ pub(crate) async fn register_route( .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - if services.users.exists(&proposed_user_id)? { + if services.users.exists(&proposed_user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -162,7 +163,7 @@ pub(crate) async fn register_route( services.globals.server_name(), ) .unwrap(); - if !services.users.exists(&proposed_user_id)? { + if !services.users.exists(&proposed_user_id).await { break proposed_user_id; } }, @@ -210,12 +211,15 @@ pub(crate) async fn register_route( if !skip_auth { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services.uiaa.try_auth( - &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), - "".into(), - auth, - &uiaainfo, - )?; + let (worked, uiaainfo) = services + .uiaa + .try_auth( + &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), + "".into(), + auth, + &uiaainfo, + ) + .await?; if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -227,7 +231,7 @@ pub(crate) async fn register_route( "".into(), &uiaainfo, &json, - )?; + ); return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -255,21 +259,23 @@ pub(crate) async fn register_route( services .users - .set_displayname(&user_id, Some(displayname.clone())) - .await?; + .set_displayname(&user_id, Some(displayname.clone())); // Initial account data - services.account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json always works"), - )?; + services + .account_data + .update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + ) + .await?; // Inhibit login does not work for guests if !is_guest && body.inhibit_login { @@ -294,13 +300,16 @@ pub(crate) async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; debug_info!(%user_id, %device_id, "User account was created"); @@ -318,7 +327,8 @@ pub(crate) async fn register_route( "New user \"{user_id}\" registered on this server from IP {client} and device display name \ \"{device_display_name}\"" ))) - .await; + .await + .ok(); } } else { info!("New user \"{user_id}\" registered on this server."); @@ -329,7 +339,8 @@ pub(crate) async fn register_route( .send_message(RoomMessageEventContent::notice_plain(format!( "New user \"{user_id}\" registered on this server from IP {client}" ))) - .await; + .await + .ok(); } } } @@ -346,7 +357,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with device display name \"{device_display_name}\" registered on \ this server from IP {client}" ))) - .await; + .await + .ok(); } } else { #[allow(clippy::collapsible_else_if)] @@ -357,7 +369,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with no device display name registered on this server from IP \ {client}", ))) - .await; + .await + .ok(); } } } @@ -365,10 +378,15 @@ pub(crate) async fn register_route( // If this is the first real user, grant them admin privileges except for guest // users Note: the server user, @conduit:servername, is generated first if !is_guest { - if let Some(admin_room) = services.admin.get_admin_room()? { - if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { + if let Ok(admin_room) = services.admin.get_admin_room().await { + if services + .rooms + .state_cache + .room_joined_count(&admin_room) + .await + .is_ok_and(is_equal_to!(1)) + { services.admin.make_user_admin(&user_id).await?; - warn!("Granting {user_id} admin privileges as the first user"); } } @@ -382,7 +400,8 @@ pub(crate) async fn register_route( if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room)? + .server_in_room(services.globals.server_name(), room) + .await { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; @@ -398,6 +417,7 @@ pub(crate) async fn register_route( None, &body.appservice_info, ) + .boxed() .await { // don't return this error so we don't fail registrations @@ -461,16 +481,20 @@ pub(crate) async fn change_password_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } - // Success! + + // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -482,14 +506,12 @@ pub(crate) async fn change_password_route( if body.logout_devices { // Logout all devices except the current one - for id in services + services .users .all_device_ids(sender_user) - .filter_map(Result::ok) - .filter(|id| id != sender_device) - { - services.users.remove_device(sender_user, &id)?; - } + .ready_filter(|id| id != sender_device) + .for_each(|id| services.users.remove_device(sender_user, id)) + .await; } info!("User {sender_user} changed their password."); @@ -500,7 +522,8 @@ pub(crate) async fn change_password_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} changed their password." ))) - .await; + .await + .ok(); } Ok(change_password::v3::Response {}) @@ -520,7 +543,7 @@ pub(crate) async fn whoami_route( Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(), + is_guest: services.users.is_deactivated(sender_user).await? && body.appservice_info.is_none(), }) } @@ -561,7 +584,9 @@ pub(crate) async fn deactivate_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -570,7 +595,8 @@ pub(crate) async fn deactivate_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -581,10 +607,14 @@ pub(crate) async fn deactivate_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(&services, sender_user, all_joined_rooms).await?; + super::update_displayname(&services, sender_user, None, &all_joined_rooms).await?; + super::update_avatar_url(&services, sender_user, None, None, &all_joined_rooms).await?; + + full_user_deactivate(&services, sender_user, &all_joined_rooms).await?; info!("User {sender_user} deactivated their account."); @@ -594,7 +624,8 @@ pub(crate) async fn deactivate_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} deactivated their account." ))) - .await; + .await + .ok(); } Ok(deactivate::v3::Response { @@ -674,34 +705,27 @@ pub(crate) async fn check_registration_token_validity( /// - Removing all profile data /// - Leaving all rooms (and forgets all of them) pub async fn full_user_deactivate( - services: &Services, user_id: &UserId, all_joined_rooms: Vec, + services: &Services, user_id: &UserId, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - services.users.deactivate_account(user_id)?; - - super::update_displayname(services, user_id, None, all_joined_rooms.clone()).await?; - super::update_avatar_url(services, user_id, None, None, all_joined_rooms.clone()).await?; + services.users.deactivate_account(user_id).await?; + super::update_displayname(services, user_id, None, all_joined_rooms).await?; + super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await?; - let all_profile_keys = services + services .users .all_profile_keys(user_id) - .filter_map(Result::ok); - - for (profile_key, _profile_value) in all_profile_keys { - if let Err(e) = services.users.set_profile_key(user_id, &profile_key, None) { - warn!("Failed removing {user_id} profile key {profile_key}: {e}"); - } - } + .ready_for_each(|(profile_key, _)| services.users.set_profile_key(user_id, &profile_key, None)) + .await; for room_id in all_joined_rooms { - let state_lock = services.rooms.state.mutex.lock(&room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; let room_power_levels = services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? - .as_ref() - .and_then(|event| serde_json::from_str(event.content.get()).ok()?) - .and_then(|content: RoomPowerLevelsEventContent| content.into()); + .room_state_get_content::(room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -710,9 +734,9 @@ pub async fn full_user_deactivate( }) || services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if user_can_demote_self { let mut power_levels_content = room_power_levels.unwrap_or_default(); @@ -732,7 +756,7 @@ pub async fn full_user_deactivate( timestamp: None, }, user_id, - &room_id, + room_id, &state_lock, ) .await diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 12d6352c9..2399a3551 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -1,11 +1,9 @@ use axum::extract::State; -use conduit::{debug, Error, Result}; +use conduit::{debug, Err, Result}; +use futures::StreamExt; use rand::seq::SliceRandom; use ruma::{ - api::client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, + api::client::alias::{create_alias, delete_alias, get_alias}, OwnedServerName, RoomAliasId, RoomId, }; use service::Services; @@ -33,16 +31,17 @@ pub(crate) async fn create_alias_route( .forbidden_alias_names() .is_match(body.room_alias.alias()) { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); + return Err!(Request(Forbidden("Room alias is forbidden."))); } if services .rooms .alias - .resolve_local_alias(&body.room_alias)? - .is_some() + .resolve_local_alias(&body.room_alias) + .await + .is_ok() { - return Err(Error::Conflict("Alias already exists.")); + return Err!(Conflict("Alias already exists.")); } services @@ -95,16 +94,16 @@ pub(crate) async fn get_alias_route( .resolve_alias(&room_alias, servers.as_ref()) .await else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); + return Err!(Request(NotFound("Room with alias not found."))); }; - let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers); + let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers).await; debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) } -fn room_available_servers( +async fn room_available_servers( services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option>, ) -> Vec { // find active servers in room state cache to suggest @@ -112,8 +111,9 @@ fn room_available_servers( .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; // push any servers we want in the list already (e.g. responded remote alias // servers, room alias server itself) diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index 4ead87776..d52da80a2 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -1,18 +1,16 @@ use axum::extract::State; +use conduit::{err, Err}; use ruma::{ - api::client::{ - backup::{ - add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, - delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, - get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, - get_latest_backup_info, update_backup_version, - }, - error::ErrorKind, + api::client::backup::{ + add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, + delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, + get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, + get_latest_backup_info, update_backup_version, }, UInt, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `POST /_matrix/client/r0/room_keys/version` /// @@ -40,7 +38,8 @@ pub(crate) async fn update_backup_version_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services .key_backups - .update_backup(sender_user, &body.version, &body.algorithm)?; + .update_backup(sender_user, &body.version, &body.algorithm) + .await?; Ok(update_backup_version::v3::Response {}) } @@ -55,14 +54,15 @@ pub(crate) async fn get_latest_backup_info_route( let (version, algorithm) = services .key_backups - .get_latest_backup(sender_user)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_latest_backup(sender_user) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?) + count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version).await) .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &version)?, + etag: services.key_backups.get_etag(sender_user, &version).await, version, }) } @@ -76,18 +76,21 @@ pub(crate) async fn get_backup_info_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let algorithm = services .key_backups - .get_backup(sender_user, &body.version)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_backup(sender_user, &body.version) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?; Ok(get_backup_info::v3::Response { algorithm, - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, version: body.version.clone(), }) } @@ -105,7 +108,8 @@ pub(crate) async fn delete_backup_version_route( services .key_backups - .delete_backup(sender_user, &body.version)?; + .delete_backup(sender_user, &body.version) + .await; Ok(delete_backup_version::v3::Response {}) } @@ -123,34 +127,36 @@ pub(crate) async fn add_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { services .key_backups - .add_key(sender_user, &body.version, room_id, session_id, key_data)?; + .add_key(sender_user, &body.version, room_id, session_id, key_data) + .await?; } } Ok(add_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -167,32 +173,34 @@ pub(crate) async fn add_backup_keys_for_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } for (session_id, key_data) in &body.sessions { services .key_backups - .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; + .add_key(sender_user, &body.version, &body.room_id, session_id, key_data) + .await?; } Ok(add_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -209,30 +217,32 @@ pub(crate) async fn add_backup_keys_for_session_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(sender_user) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } services .key_backups - .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; + .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data) + .await?; Ok(add_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -244,7 +254,10 @@ pub(crate) async fn get_backup_keys_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = services.key_backups.get_all(sender_user, &body.version)?; + let rooms = services + .key_backups + .get_all(sender_user, &body.version) + .await; Ok(get_backup_keys::v3::Response { rooms, @@ -261,7 +274,8 @@ pub(crate) async fn get_backup_keys_for_room_route( let sessions = services .key_backups - .get_room(sender_user, &body.version, &body.room_id)?; + .get_room(sender_user, &body.version, &body.room_id) + .await; Ok(get_backup_keys_for_room::v3::Response { sessions, @@ -278,8 +292,9 @@ pub(crate) async fn get_backup_keys_for_session_route( let key_data = services .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; + .get_session(sender_user, &body.version, &body.room_id, &body.session_id) + .await + .map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?; Ok(get_backup_keys_for_session::v3::Response { key_data, @@ -296,16 +311,19 @@ pub(crate) async fn delete_backup_keys_route( services .key_backups - .delete_all_keys(sender_user, &body.version)?; + .delete_all_keys(sender_user, &body.version) + .await; Ok(delete_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -319,16 +337,19 @@ pub(crate) async fn delete_backup_keys_for_room_route( services .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id)?; + .delete_room_keys(sender_user, &body.version, &body.room_id) + .await; Ok(delete_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } @@ -342,15 +363,18 @@ pub(crate) async fn delete_backup_keys_for_session_route( services .key_backups - .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; + .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id) + .await; Ok(delete_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(sender_user, &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(sender_user, &body.version) + .await, }) } diff --git a/src/api/client/capabilities.rs b/src/api/client/capabilities.rs index 83e1dc7e6..89157e471 100644 --- a/src/api/client/capabilities.rs +++ b/src/api/client/capabilities.rs @@ -3,7 +3,8 @@ use std::collections::BTreeMap; use axum::extract::State; use ruma::{ api::client::discovery::get_capabilities::{ - self, Capabilities, RoomVersionStability, RoomVersionsCapability, ThirdPartyIdChangesCapability, + self, Capabilities, GetLoginTokenCapability, RoomVersionStability, RoomVersionsCapability, + ThirdPartyIdChangesCapability, }, RoomVersionId, }; @@ -43,6 +44,11 @@ pub(crate) async fn get_capabilities_route( enabled: false, }; + // we dont support generating tokens yet + capabilities.get_login_token = GetLoginTokenCapability { + enabled: false, + }; + // MSC4133 capability capabilities .set("uk.tcpip.msc4133.profile_fields", json!({"enabled": true})) diff --git a/src/api/client/config.rs b/src/api/client/config.rs index 61cc97ff5..33b85136c 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::err; use ruma::{ api::client::{ config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, @@ -25,7 +26,8 @@ pub(crate) async fn set_global_account_data_route( &body.sender_user, &body.event_type.to_string(), body.data.json(), - )?; + ) + .await?; Ok(set_global_account_data::v3::Response {}) } @@ -42,7 +44,8 @@ pub(crate) async fn set_room_account_data_route( &body.sender_user, &body.event_type.to_string(), body.data.json(), - )?; + ) + .await?; Ok(set_room_account_data::v3::Response {}) } @@ -57,8 +60,9 @@ pub(crate) async fn get_global_account_data_route( let event: Box = services .account_data - .get(None, sender_user, body.event_type.to_string().into())? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + .get(None, sender_user, body.event_type.to_string().into()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -79,8 +83,9 @@ pub(crate) async fn get_room_account_data_route( let event: Box = services .account_data - .get(Some(&body.room_id), sender_user, body.event_type.clone())? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + .get(Some(&body.room_id), sender_user, body.event_type.clone()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -91,7 +96,7 @@ pub(crate) async fn get_room_account_data_route( }) } -fn set_account_data( +async fn set_account_data( services: &Services, room_id: Option<&RoomId>, sender_user: &Option, event_type: &str, data: &RawJsonValue, ) -> Result<()> { @@ -100,15 +105,18 @@ fn set_account_data( let data: serde_json::Value = serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - services.account_data.update( - room_id, - sender_user, - event_type.into(), - &json!({ - "type": event_type, - "content": data, - }), - )?; + services + .account_data + .update( + room_id, + sender_user, + event_type.into(), + &json!({ + "type": event_type, + "content": data, + }), + ) + .await?; Ok(()) } diff --git a/src/api/client/context.rs b/src/api/client/context.rs index f223d4889..cc49b763f 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,13 +1,14 @@ use std::collections::HashSet; use axum::extract::State; +use conduit::{err, error, Err}; +use futures::StreamExt; use ruma::{ - api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, + api::client::{context::get_context, filter::LazyLoadOptions}, events::StateEventType, }; -use tracing::error; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// @@ -35,34 +36,33 @@ pub(crate) async fn get_context_route( let base_token = services .rooms .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; + .get_pdu_count(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Base event id not found."))))?; let base_event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?; + .get_pdu(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Base event not found."))))?; - let room_id = base_event.room_id.clone(); + let room_id = &base_event.room_id; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &body.event_id)? + .user_can_see_event(sender_user, room_id, &body.event_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this event.", - )); + return Err!(Request(Forbidden("You don't have permission to view this event."))); } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &base_event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &base_event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(base_event.sender.as_str().to_owned()); } @@ -75,25 +75,26 @@ pub(crate) async fn get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_until(sender_user, &room_id, base_token)? + .pdus_until(sender_user, room_id, base_token) + .await? .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect(); + .collect() + .await; for (_, event) in &events_before { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(event.sender.as_str().to_owned()); } @@ -111,25 +112,26 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus_after(sender_user, &room_id, base_token)? + .pdus_after(sender_user, room_id, base_token) + .await? .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect(); + .collect() + .await; for (_, event) in &events_after { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant { lazy_loaded.insert(event.sender.as_str().to_owned()); } @@ -142,12 +144,14 @@ pub(crate) async fn get_context_route( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), - )? + ) + .await .map_or( services .rooms .state - .get_room_shortstatehash(&room_id)? + .get_room_shortstatehash(room_id) + .await .expect("All rooms have state"), |hash| hash, ); @@ -156,7 +160,8 @@ pub(crate) async fn get_context_route( .rooms .state_accessor .state_full_ids(shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State not found: {e}")))?; let end_token = events_after .last() @@ -173,18 +178,19 @@ pub(crate) async fn get_context_route( let (event_type, state_key) = services .rooms .short - .get_statekey_from_short(shortstatekey)?; + .get_statekey_from_short(shortstatekey) + .await?; if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; diff --git a/src/api/client/device.rs b/src/api/client/device.rs index bad7f2844..93eaa393d 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::{err, Err}; +use futures::StreamExt; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -19,8 +21,8 @@ pub(crate) async fn get_devices_route( let devices: Vec = services .users .all_devices_metadata(sender_user) - .filter_map(Result::ok) // Filter out buggy devices - .collect(); + .collect() + .await; Ok(get_devices::v3::Response { devices, @@ -37,8 +39,9 @@ pub(crate) async fn get_device_route( let device = services .users - .get_device_metadata(sender_user, &body.body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; Ok(get_device::v3::Response { device, @@ -55,14 +58,16 @@ pub(crate) async fn update_device_route( let mut device = services .users - .get_device_metadata(sender_user, &body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; device.display_name.clone_from(&body.display_name); services .users - .update_device_metadata(sender_user, &body.device_id, &device)?; + .update_device_metadata(sender_user, &body.device_id, &device) + .await?; Ok(update_device::v3::Response {}) } @@ -97,22 +102,28 @@ pub(crate) async fn delete_device_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { - return Err(Error::Uiaa(uiaainfo)); + return Err!(Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); + .create(sender_user, sender_device, &uiaainfo, &json); + + return Err!(Uiaa(uiaainfo)); } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + return Err!(Request(NotJson("Not json."))); } - services.users.remove_device(sender_user, &body.device_id)?; + services + .users + .remove_device(sender_user, &body.device_id) + .await; Ok(delete_device::v3::Response {}) } @@ -149,7 +160,9 @@ pub(crate) async fn delete_devices_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -158,14 +171,15 @@ pub(crate) async fn delete_devices_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } for device_id in &body.devices { - services.users.remove_device(sender_user, device_id)?; + services.users.remove_device(sender_user, device_id).await; } Ok(delete_devices::v3::Response {}) diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 602f876a9..ea499545c 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,6 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{err, info, warn, Err, Error, Result}; +use conduit::{info, warn, Err, Error, Result}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ api::{ client::{ @@ -18,7 +19,7 @@ use ruma::{ }, StateEventType, }, - uint, RoomId, ServerName, UInt, UserId, + uint, OwnedRoomId, RoomId, ServerName, UInt, UserId, }; use service::Services; @@ -119,16 +120,22 @@ pub(crate) async fn set_room_visibility_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } - if services.users.is_deactivated(sender_user).unwrap_or(false) && body.appservice_info.is_none() { + if services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && body.appservice_info.is_none() + { return Err!(Request(Forbidden("Guests cannot publish to room directories"))); } - if !user_can_publish_room(&services, sender_user, &body.room_id)? { + if !user_can_publish_room(&services, sender_user, &body.room_id).await? { return Err(Error::BadRequest( ErrorKind::forbidden(), "User is not allowed to publish this room", @@ -138,7 +145,7 @@ pub(crate) async fn set_room_visibility_route( match &body.visibility { room::Visibility::Public => { if services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -164,7 +171,7 @@ pub(crate) async fn set_room_visibility_route( )); } - services.rooms.directory.set_public(&body.room_id)?; + services.rooms.directory.set_public(&body.room_id); if services.globals.config.admin_room_notices { services @@ -174,7 +181,7 @@ pub(crate) async fn set_room_visibility_route( } info!("{sender_user} made {0} public to the room directory", body.room_id); }, - room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?, + room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id), _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -192,13 +199,13 @@ pub(crate) async fn set_room_visibility_route( pub(crate) async fn get_room_visibility_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } Ok(get_room_visibility::v3::Response { - visibility: if services.rooms.directory.is_public_room(&body.room_id)? { + visibility: if services.rooms.directory.is_public_room(&body.room_id).await { room::Visibility::Public } else { room::Visibility::Private @@ -257,101 +264,41 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = services + let mut all_rooms: Vec = services .rooms .directory .public_rooms() - .map(|room_id| { - let room_id = room_id?; - - let chunk = PublicRoomsChunk { - canonical_alias: services - .rooms - .state_accessor - .get_canonical_alias(&room_id)?, - name: services.rooms.state_accessor.get_name(&room_id)?, - num_joined_members: services - .rooms - .state_cache - .room_joined_count(&room_id)? - .unwrap_or_else(|| { - warn!("Room {} has no member count", room_id); - 0 - }) - .try_into() - .expect("user count should not be that big"), - topic: services - .rooms - .state_accessor - .get_room_topic(&room_id) - .unwrap_or(None), - world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?, - guest_can_join: services - .rooms - .state_accessor - .guest_can_join(&room_id)?, - avatar_url: services - .rooms - .state_accessor - .get_avatar(&room_id)? - .into_option() - .unwrap_or_default() - .url, - join_rule: services - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| match c.join_rule { - JoinRule::Public => Some(PublicRoomJoinRule::Public), - JoinRule::Knock => Some(PublicRoomJoinRule::Knock), - _ => None, - }) - .map_err(|e| { - err!(Database(error!("Invalid room join rule event in database: {e}"))) - }) - }) - .transpose()? - .flatten() - .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, - room_type: services - .rooms - .state_accessor - .get_room_type(&room_id)?, - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { + .map(ToOwned::to_owned) + .then(|room_id| public_rooms_chunk(services, room_id)) + .filter_map(|chunk| async move { if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { if let Some(name) = &chunk.name { if name.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(topic) = &chunk.topic { if topic.to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(canonical_alias) = &chunk.canonical_alias { if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } - false - } else { - // No search term - true + return None; } + + // No search term + Some(chunk) }) // We need to collect all, so we can sort by member count - .collect(); + .collect() + .await; all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); @@ -394,22 +341,23 @@ pub(crate) async fn get_public_rooms_filtered_helper( /// Check whether the user can publish to the room directory via power levels of /// room history visibility event or room creator -fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result { - if let Some(event) = services +async fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result { + if let Ok(event) = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) .map(|content: RoomPowerLevelsEventContent| { RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) }) - } else if let Some(event) = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? + } else if let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -419,3 +367,61 @@ fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId )); } } + +async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> PublicRoomsChunk { + PublicRoomsChunk { + canonical_alias: services + .rooms + .state_accessor + .get_canonical_alias(&room_id) + .await + .ok(), + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), + num_joined_members: services + .rooms + .state_cache + .room_joined_count(&room_id) + .await + .unwrap_or(0) + .try_into() + .expect("joined count overflows ruma UInt"), + topic: services + .rooms + .state_accessor + .get_room_topic(&room_id) + .await + .ok(), + world_readable: services + .rooms + .state_accessor + .is_world_readable(&room_id) + .await, + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + avatar_url: services + .rooms + .state_accessor + .get_avatar(&room_id) + .await + .into_option() + .unwrap_or_default() + .url, + join_rule: services + .rooms + .state_accessor + .room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "") + .map_ok(|c: RoomJoinRulesEventContent| match c.join_rule { + JoinRule::Public => PublicRoomJoinRule::Public, + JoinRule::Knock => PublicRoomJoinRule::Knock, + _ => "invite".into(), + }) + .await + .unwrap_or_default(), + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_id, + } +} diff --git a/src/api/client/filter.rs b/src/api/client/filter.rs index 8b2690c69..2a8ebb9c2 100644 --- a/src/api/client/filter.rs +++ b/src/api/client/filter.rs @@ -1,10 +1,8 @@ use axum::extract::State; -use ruma::api::client::{ - error::ErrorKind, - filter::{create_filter, get_filter}, -}; +use conduit::err; +use ruma::api::client::filter::{create_filter, get_filter}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// @@ -15,11 +13,13 @@ pub(crate) async fn get_filter_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); - }; - Ok(get_filter::v3::Response::new(filter)) + services + .users + .get_filter(sender_user, &body.filter_id) + .await + .map(get_filter::v3::Response::new) + .map_err(|_| err!(Request(NotFound("Filter not found.")))) } /// # `PUT /_matrix/client/r0/user/{userId}/filter` @@ -29,7 +29,8 @@ pub(crate) async fn create_filter_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(create_filter::v3::Response::new( - services.users.create_filter(sender_user, &body.filter)?, - )) + + let filter_id = services.users.create_filter(sender_user, &body.filter); + + Ok(create_filter::v3::Response::new(filter_id)) } diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index a426364a2..254d92ccd 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -4,8 +4,8 @@ use std::{ }; use axum::extract::State; -use conduit::{utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{err, utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ client::{ @@ -21,7 +21,10 @@ use ruma::{ use serde_json::json; use super::SESSION_ID_LENGTH; -use crate::{service::Services, Ruma}; +use crate::{ + service::{users::parse_master_key, Services}, + Ruma, +}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -39,7 +42,8 @@ pub(crate) async fn upload_keys_route( for (key_key, key_value) in &body.one_time_keys { services .users - .add_one_time_key(sender_user, sender_device, key_key, key_value)?; + .add_one_time_key(sender_user, sender_device, key_key, key_value) + .await?; } if let Some(device_keys) = &body.device_keys { @@ -47,19 +51,22 @@ pub(crate) async fn upload_keys_route( // This check is needed to assure that signatures are kept if services .users - .get_device_keys(sender_user, sender_device)? - .is_none() + .get_device_keys(sender_user, sender_device) + .await + .is_err() { services .users - .add_device_keys(sender_user, sender_device, device_keys)?; + .add_device_keys(sender_user, sender_device, device_keys) + .await; } } Ok(upload_keys::v3::Response { one_time_key_counts: services .users - .count_one_time_keys(sender_user, sender_device)?, + .count_one_time_keys(sender_user, sender_device) + .await, }) } @@ -120,7 +127,9 @@ pub(crate) async fn upload_signing_keys_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -129,20 +138,24 @@ pub(crate) async fn upload_signing_keys_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } if let Some(master_key) = &body.master_key { - services.users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; + services + .users + .add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, // notify so that other users see the new keys + ) + .await?; } Ok(upload_signing_keys::v3::Response {}) @@ -179,9 +192,11 @@ pub(crate) async fn upload_signatures_route( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .to_owned(), ); + services .users - .sign_key(user_id, key_id, signature, sender_user)?; + .sign_key(user_id, key_id, signature, sender_user) + .await?; } } } @@ -204,56 +219,51 @@ pub(crate) async fn get_key_changes_route( let mut device_list_updates = HashSet::new(); + let from = body + .from + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?; + + let to = body + .to + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?; + device_list_updates.extend( services .users - .keys_changed( - sender_user.as_str(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(sender_user.as_str(), from, Some(to)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); - for room_id in services - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(Result::ok) - { + let mut rooms_joined = services.rooms.state_cache.rooms_joined(sender_user).boxed(); + + while let Some(room_id) = rooms_joined.next().await { device_list_updates.extend( services .users - .keys_changed( - room_id.as_ref(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(room_id.as_str(), from, Some(to)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); } + Ok(get_key_changes::v3::Response { changed: device_list_updates.into_iter().collect(), left: Vec::new(), // TODO }) } -pub(crate) async fn get_keys_helper bool + Send>( +pub(crate) async fn get_keys_helper( services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap>, allowed_signatures: F, include_display_names: bool, -) -> Result { +) -> Result +where + F: Fn(&UserId) -> bool + Send + Sync, +{ let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new(); @@ -274,56 +284,60 @@ pub(crate) async fn get_keys_helper bool + Send>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in services.users.all_device_ids(user_id) { - let device_id = device_id?; - if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? { + let mut devices = services.users.all_device_ids(user_id).boxed(); + + while let Some(device_id) = devices.next().await { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, &device_id)? - .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Database("all_device_keys contained nonexistent device.")))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; - container.insert(device_id, keys); + container.insert(device_id.to_owned(), keys); } } + device_keys.insert(user_id.to_owned(), container); } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, device_id)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to get keys for nonexistent device.", - ))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to get keys for nonexistent device."))))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; + container.insert(device_id.to_owned(), keys); } + device_keys.insert(user_id.to_owned(), container); } } - if let Some(master_key) = services + if let Ok(master_key) = services .users - .get_master_key(sender_user, user_id, &allowed_signatures)? + .get_master_key(sender_user, user_id, &allowed_signatures) + .await { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = - services - .users - .get_self_signing_key(sender_user, user_id, &allowed_signatures)? + if let Ok(self_signing_key) = services + .users + .get_self_signing_key(sender_user, user_id, &allowed_signatures) + .await { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? { + if let Ok(user_signing_key) = services.users.get_user_signing_key(user_id).await { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -386,23 +400,26 @@ pub(crate) async fn get_keys_helper bool + Send>( while let Some((server, response)) = futures.next().await { if let Ok(Ok(response)) = response { for (user, masterkey) in response.master_keys { - let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?; + let (master_key_id, mut master_key) = parse_master_key(&user, &masterkey)?; - if let Some(our_master_key) = - services - .users - .get_key(&master_key_id, sender_user, &user, &allowed_signatures)? + if let Ok(our_master_key) = services + .users + .get_key(&master_key_id, sender_user, &user, &allowed_signatures) + .await { - let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?; + let (_, our_master_key) = parse_master_key(&user, &our_master_key)?; master_key.signatures.extend(our_master_key.signatures); } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services.users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, /* Dont notify. A notification would trigger another key request resulting in an - * endless loop */ - )?; + services + .users + .add_cross_signing_keys( + &user, &raw, &None, &None, + false, /* Dont notify. A notification would trigger another key request resulting in an + * endless loop */ + ) + .await?; master_keys.insert(user.clone(), raw); } @@ -465,9 +482,10 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = services + if let Ok(one_time_keys) = services .users - .take_one_time_key(user_id, device_id, key_algorithm)? + .take_one_time_key(user_id, device_id, key_algorithm) + .await { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 470db6693..6e3bc8940 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -11,9 +11,10 @@ use conduit::{ debug, debug_error, debug_warn, err, error, info, pdu::{gen_event_id_canonical_json, PduBuilder}, trace, utils, - utils::math::continue_exponential_backoff_secs, + utils::{math::continue_exponential_backoff_secs, IterStream, ReadyExt}, warn, Err, Error, PduEvent, Result, }; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{ client::{ @@ -55,9 +56,9 @@ async fn banned_room_check( services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, client_ip: IpAddr, ) -> Result<()> { - if !services.users.is_admin(user_id)? { + if !services.users.is_admin(user_id).await { if let Some(room_id) = room_id { - if services.rooms.metadata.is_banned(room_id)? + if services.rooms.metadata.is_banned(room_id).await || services .globals .config @@ -79,23 +80,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This room is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This room is banned on this homeserver."))); } } else if let Some(server_name) = server_name { if services @@ -119,23 +119,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This remote server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This remote server is banned on this homeserver."))); } } } @@ -168,18 +167,20 @@ pub(crate) async fn join_room_by_id_route( .await?; // There is no body.server_name for /roomId/join - let mut servers = services + let mut servers: Vec<_> = services .rooms .state_cache .servers_invite_via(&body.room_id) - .filter_map(Result::ok) - .collect::>(); + .map(ToOwned::to_owned) + .collect() + .await; servers.extend( services .rooms .state_cache - .invite_state(sender_user, &body.room_id)? + .invite_state(sender_user, &body.room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -202,6 +203,7 @@ pub(crate) async fn join_room_by_id_route( body.third_party_signed.as_ref(), &body.appservice_info, ) + .boxed() .await } @@ -233,14 +235,17 @@ pub(crate) async fn join_room_by_id_or_alias_route( .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? + .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -270,19 +275,23 @@ pub(crate) async fn join_room_by_id_or_alias_route( if let Some(pre_servers) = &mut pre_servers { servers.append(pre_servers); } + servers.extend( services .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? + .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() .filter_map(|event| serde_json::from_str(event.json().get()).ok()) @@ -305,6 +314,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( body.third_party_signed.as_ref(), appservice_info, ) + .boxed() .await?; Ok(join_room_by_id_or_alias::v3::Response { @@ -337,7 +347,7 @@ pub(crate) async fn invite_user_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.users.is_admin(sender_user)? && services.globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() { info!( "User {sender_user} is not an admin and attempted to send an invite to room {}", &body.room_id @@ -375,15 +385,13 @@ pub(crate) async fn kick_user_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot kick member that's not in the room.", - ))? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))? .content .get(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + .map_err(|_| err!(Database("Invalid member event in database.")))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -421,10 +429,13 @@ pub(crate) async fn ban_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + let blurhash = services.users.blurhash(&body.user_id).await.ok(); + let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await .map_or( Ok(RoomMemberEventContent { membership: MembershipState::Ban, @@ -432,7 +443,7 @@ pub(crate) async fn ban_user_route( avatar_url: None, is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, }), @@ -442,12 +453,12 @@ pub(crate) async fn ban_user_route( membership: MembershipState::Ban, displayname: None, avatar_url: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, ..event }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) + .map_err(|e| err!(Database("Invalid member event in database: {e:?}"))) }, )?; @@ -488,12 +499,13 @@ pub(crate) async fn unban_user_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))? + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))? .content .get(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + .map_err(|e| err!(Database("Invalid member event in database: {e:?}")))?; event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); @@ -539,18 +551,16 @@ pub(crate) async fn forget_room_route( if services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? + .is_joined(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "You must leave the room before forgetting it", - )); + return Err!(Request(Unknown("You must leave the room before forgetting it"))); } services .rooms .state_cache - .forget(&body.room_id, sender_user)?; + .forget(&body.room_id, sender_user); Ok(forget_room::v3::Response::new()) } @@ -568,8 +578,9 @@ pub(crate) async fn joined_rooms_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -587,12 +598,10 @@ pub(crate) async fn get_member_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } Ok(get_member_events::v3::Response { @@ -622,30 +631,28 @@ pub(crate) async fn joined_members_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } let joined: BTreeMap = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(|user| { - let user = user.ok()?; - - Some(( + .map(ToOwned::to_owned) + .then(|user| async move { + ( user.clone(), RoomMember { - display_name: services.users.displayname(&user).unwrap_or_default(), - avatar_url: services.users.avatar_url(&user).unwrap_or_default(), + display_name: services.users.displayname(&user).await.ok(), + avatar_url: services.users.avatar_url(&user).await.ok(), }, - )) + ) }) - .collect(); + .collect() + .await; Ok(joined_members::v3::Response { joined, @@ -658,13 +665,23 @@ pub async fn join_room_by_id_helper( ) -> Result { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let user_is_guest = services.users.is_deactivated(sender_user).unwrap_or(false) && appservice_info.is_none(); + let user_is_guest = services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && appservice_info.is_none(); - if matches!(services.rooms.state_accessor.guest_can_join(room_id), Ok(false)) && user_is_guest { + if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await { return Err!(Request(Forbidden("Guests are not allowed to join this room"))); } - if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { + if services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), @@ -674,15 +691,17 @@ pub async fn join_room_by_id_helper( if services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? - || servers.is_empty() + .server_in_room(services.globals.server_name(), room_id) + .await || servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .boxed() .await } else { // Ask a remote server if we are not participating in this room join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .boxed() .await } } @@ -739,11 +758,11 @@ async fn join_room_by_id_helper_remote( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server: join_authorized_via_users_server.clone(), }) @@ -791,10 +810,11 @@ async fn join_room_by_id_helper_remote( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + .convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -864,7 +884,11 @@ async fn join_room_by_id_helper_remote( } } - services.rooms.short.get_or_create_shortroomid(room_id)?; + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; info!("Parsing join event"); let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) @@ -895,12 +919,13 @@ async fn join_room_by_id_helper_remote( err!(BadServerResponse("Invalid PDU in send_join response: {e:?}")) })?; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); if let Some(state_key) = &pdu.state_key { let shortstatekey = services .rooms .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; state.insert(shortstatekey, pdu.event_id.clone()); } } @@ -916,50 +941,53 @@ async fn join_room_by_id_helper_remote( continue; }; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value); } debug!("Running send_join auth check"); + let fetch_state = &state; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = fetch_state.get(&shortstatekey)?; + services.rooms.timeline.get_pdu(event_id).await.ok() + }; let auth_check = state_res::event_auth::auth_check( &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), &parsed_join_pdu, - None::, // TODO: third party invite - |k, s| { - services - .rooms - .timeline - .get_pdu( - state.get( - &services - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? - }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); + return Err!(Request(Forbidden("Auth check failed"))); } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id)) - .collect::>()?, - ), - )?; + let (statehash_before_join, new, removed) = services + .rooms + .state_compressor + .save_state( + room_id, + Arc::new( + state + .into_iter() + .stream() + .then(|(k, id)| async move { + services + .rooms + .state_compressor + .compress_state_event(k, &id) + .await + }) + .collect() + .await, + ), + ) + .await?; services .rooms @@ -968,12 +996,20 @@ async fn join_room_by_id_helper_remote( .await?; info!("Updating joined counts for new room"); - services.rooms.state_cache.update_joined_count(room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services + .rooms + .state + .append_to_state(&parsed_join_pdu) + .await?; info!("Appending new room join event"); services @@ -993,7 +1029,7 @@ async fn join_room_by_id_helper_remote( services .rooms .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; + .set_room_state(room_id, statehash_after_join, &state_lock); Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } @@ -1005,23 +1041,15 @@ async fn join_room_by_id_helper_local( ) -> Result { debug!("We can join locally"); - let join_rules_event = services + let join_rules_event_content = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|content: RoomJoinRulesEventContent| content); let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { + Ok(RoomJoinRulesEventContent { join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), }) => restricted .allow @@ -1034,29 +1062,34 @@ async fn join_room_by_id_helper_local( _ => Vec::new(), }; - let local_members = services + let local_members: Vec<_> = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|user| services.globals.user_is_local(user)) - .collect::>(); + .ready_filter(|user| services.globals.user_is_local(user)) + .map(ToOwned::to_owned) + .collect() + .await; let mut join_authorized_via_users_server: Option = None; - if restriction_rooms.iter().any(|restriction_room_id| { - services - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { + if restriction_rooms + .iter() + .stream() + .any(|restriction_room_id| { + services + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + }) + .await + { for user in local_members { if services .rooms .state_accessor .user_can_invite(room_id, &user, sender_user, &state_lock) - .unwrap_or(false) + .await { join_authorized_via_users_server = Some(user); break; @@ -1066,11 +1099,11 @@ async fn join_room_by_id_helper_local( let event = RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: reason.clone(), join_authorized_via_users_server, }; @@ -1144,11 +1177,11 @@ async fn join_room_by_id_helper_local( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server, }) @@ -1195,10 +1228,11 @@ async fn join_room_by_id_helper_local( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + .convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -1369,7 +1403,7 @@ pub(crate) async fn invite_helper( services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option, is_direct: bool, ) -> Result<()> { - if !services.users.is_admin(user_id)? && services.globals.block_non_admin_invites() { + if !services.users.is_admin(user_id).await && services.globals.block_non_admin_invites() { info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -1381,7 +1415,7 @@ pub(crate) async fn invite_helper( let (pdu, pdu_json, invite_room_state) = { let state_lock = services.rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { - avatar_url: services.users.avatar_url(user_id)?, + avatar_url: services.users.avatar_url(user_id).await.ok(), displayname: None, is_direct: Some(is_direct), membership: MembershipState::Invite, @@ -1392,28 +1426,32 @@ pub(crate) async fn invite_helper( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; - let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services.rooms.state.calculate_invite_state(&pdu).await?; drop(state_lock); (pdu, pdu_json, invite_room_state) }; - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let response = services .sending @@ -1425,9 +1463,15 @@ pub(crate) async fn invite_helper( room_version: room_version_id.clone(), event: services .sending - .convert_to_outgoing_federation_event(pdu_json.clone()), + .convert_to_outgoing_federation_event(pdu_json.clone()) + .await, invite_room_state, - via: services.rooms.state_cache.servers_route_via(room_id).ok(), + via: services + .rooms + .state_cache + .servers_route_via(room_id) + .await + .ok(), }, ) .await?; @@ -1478,11 +1522,16 @@ pub(crate) async fn invite_helper( "Could not accept incoming PDU as timeline event.", ))?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; return Ok(()); } - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -1499,11 +1548,11 @@ pub(crate) async fn invite_helper( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Invite, - displayname: services.users.displayname(user_id)?, - avatar_url: services.users.avatar_url(user_id)?, + displayname: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), is_direct: Some(is_direct), third_party_invite: None, - blurhash: services.users.blurhash(user_id)?, + blurhash: services.users.blurhash(user_id).await.ok(), reason, join_authorized_via_users_server: None, }) @@ -1527,40 +1576,41 @@ pub(crate) async fn invite_helper( // Make a user leave all their joined rooms, forgets all rooms, and ignores // errors pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { - let all_rooms = services + let all_rooms: Vec<_> = services .rooms .state_cache .rooms_joined(user_id) + .map(ToOwned::to_owned) .chain( services .rooms .state_cache .rooms_invited(user_id) - .map(|t| t.map(|(r, _)| r)), + .map(|(r, _)| r), ) - .collect::>(); + .collect() + .await; for room_id in all_rooms { - let Ok(room_id) = room_id else { - continue; - }; - // ignore errors if let Err(e) = leave_room(services, user_id, &room_id, None).await { warn!(%room_id, %user_id, %e, "Failed to leave room"); } - if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) { - warn!(%room_id, %user_id, %e, "Failed to forget room"); - } + + services.rooms.state_cache.forget(&room_id, user_id); } } pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option) -> Result<()> { + //use conduit::utils::stream::OptionStream; + use futures::TryFutureExt; + // Ask a remote server if we don't have this room if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? + .server_in_room(services.globals.server_name(), room_id) + .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { warn!("Failed to leave room {} remotely: {}", user_id, e); @@ -1570,34 +1620,42 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, let last_state = services .rooms .state_cache - .invite_state(user_id, room_id)? - .map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; + .invite_state(user_id, room_id) + .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id)) + .await + .ok(); // We always drop the invite, we can't rely on other servers - services.rooms.state_cache.update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - last_state, - None, - true, - )?; + services + .rooms + .state_cache + .update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + last_state, + None, + true, + ) + .await?; } else { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let member_event = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; + let member_event = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await; // Fix for broken rooms - let member_event = match member_event { - None => { - error!("Trying to leave a room you are not a member of."); + let Ok(member_event) = member_event else { + error!("Trying to leave a room you are not a member of."); - services.rooms.state_cache.update_membership( + services + .rooms + .state_cache + .update_membership( room_id, user_id, RoomMemberEventContent::new(MembershipState::Leave), @@ -1605,16 +1663,14 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, None, None, true, - )?; - return Ok(()); - }, - Some(e) => e, + ) + .await?; + + return Ok(()); }; - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| { - error!("Invalid room member event in database: {}", e); - Error::bad_database("Invalid member event in database.") - })?; + let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) + .map_err(|e| err!(Database(error!("Invalid room member event in database: {e}"))))?; event.membership = MembershipState::Leave; event.reason = reason; @@ -1647,15 +1703,17 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room let invite_state = services .rooms .state_cache - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; + .invite_state(user_id, room_id) + .await + .map_err(|_| err!(Request(BadState("User is not invited."))))?; let mut servers: HashSet = services .rooms .state_cache .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; servers.extend( invite_state @@ -1760,7 +1818,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room event_id, pdu: services .sending - .convert_to_outgoing_federation_event(leave_event.clone()), + .convert_to_outgoing_federation_event(leave_event.clone()) + .await, }, ) .await?; diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 51aee8c12..bab5fa54f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,7 +1,8 @@ use std::collections::{BTreeMap, HashSet}; use axum::extract::State; -use conduit::PduCount; +use conduit::{err, utils::ReadyExt, Err, PduCount}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -9,13 +10,14 @@ use ruma::{ message::{get_message_events, send_message_event}, }, events::{MessageLikeEventType, StateEventType}, - RoomId, UserId, + UserId, }; use serde_json::{from_str, Value}; +use service::rooms::timeline::PdusIterItem; use crate::{ service::{pdu::PduBuilder, Services}, - utils, Error, PduEvent, Result, Ruma, + utils, Error, Result, Ruma, }; /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` @@ -30,79 +32,78 @@ use crate::{ pub(crate) async fn send_message_event_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); - - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + let appservice_info = body.appservice_info.as_ref(); // Forbid m.room.encrypted if encryption is disabled if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); + return Err!(Request(Forbidden("Encryption has been disabled"))); } - if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Room call invites are not allowed in public rooms", - )); + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + if body.event_type == MessageLikeEventType::CallInvite + && services.rooms.directory.is_public_room(&body.room_id).await + { + return Err!(Request(Forbidden("Room call invites are not allowed in public rooms"))); } // Check if this is a new transaction id - if let Some(response) = services + if let Ok(response) = services .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await { // The client might have sent a txnid of the /sendToDevice endpoint // This txnid has no response associated with it if response.is_empty() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to use txn id already used for an incompatible endpoint.", - )); + return Err!(Request(InvalidParam( + "Tried to use txn id already used for an incompatible endpoint." + ))); } - let event_id = utils::string_from_bytes(&response) - .map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; return Ok(send_message_event::v3::Response { - event_id, + event_id: utils::string_from_bytes(&response) + .map(TryInto::try_into) + .map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??, }); } let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); + let content = from_str(body.body.body.json().get()) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?; + let event_id = services .rooms .timeline .build_and_append_pdu( PduBuilder { event_type: body.event_type.to_string().into(), - content: from_str(body.body.body.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, + content, unsigned: Some(unsigned), state_key: None, redacts: None, - timestamp: if body.appservice_info.is_some() { - body.timestamp - } else { - None - }, + timestamp: appservice_info.and(body.timestamp), }, sender_user, &body.room_id, &state_lock, ) - .await?; + .await + .map(|event_id| (*event_id).to_owned())?; services .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; + .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes()); drop(state_lock); - Ok(send_message_event::v3::Response::new((*event_id).to_owned())) + Ok(send_message_event::v3::Response { + event_id, + }) } /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` @@ -117,8 +118,12 @@ pub(crate) async fn get_message_events_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, + let room_id = &body.room_id; + let filter = &body.filter; + + let limit = usize::try_from(body.limit).unwrap_or(10).min(100); + let from = match body.from.as_ref() { + Some(from) => PduCount::try_from_string(from)?, None => match body.dir { ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Backward => PduCount::max(), @@ -133,30 +138,25 @@ pub(crate) async fn get_message_events_route( services .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) - .await?; - - let limit = usize::try_from(body.limit).unwrap_or(10).min(100); - - let next_token; + .lazy_load_confirm_delivery(sender_user, sender_device, room_id, from); let mut resp = get_message_events::v3::Response::new(); - let mut lazy_loaded = HashSet::new(); - + let next_token; match body.dir { ruma::api::Direction::Forward => { - let events_after: Vec<_> = services + let events_after: Vec = services .rooms .timeline - .pdus_after(sender_user, &body.room_id, from)? - .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id) - - }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .pdus_after(sender_user, room_id, from) + .await? + .ready_filter_map(|item| contains_url_filter(item, filter)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) + .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` .take(limit) - .collect(); + .collect() + .boxed() + .await; for (_, event) in &events_after { /* TODO: Remove the not "element_hacks" check when these are resolved: @@ -164,16 +164,18 @@ pub(crate) async fn get_message_events_route( * https://github.com/vector-im/element-web/issues/21034 */ if !cfg!(feature = "element_hacks") - && !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &body.room_id, - &event.sender, - )? { + && !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await + { lazy_loaded.insert(event.sender.clone()); } - lazy_loaded.insert(event.sender.clone()); + if cfg!(features = "element_hacks") { + lazy_loaded.insert(event.sender.clone()); + } } next_token = events_after.last().map(|(count, _)| count).copied(); @@ -191,17 +193,22 @@ pub(crate) async fn get_message_events_route( services .rooms .timeline - .backfill_if_required(&body.room_id, from) + .backfill_if_required(room_id, from) + .boxed() .await?; - let events_before: Vec<_> = services + + let events_before: Vec = services .rooms .timeline - .pdus_until(sender_user, &body.room_id, from)? - .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)}) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .pdus_until(sender_user, room_id, from) + .await? + .ready_filter_map(|item| contains_url_filter(item, filter)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) + .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to` .take(limit) - .collect(); + .collect() + .boxed() + .await; for (_, event) in &events_before { /* TODO: Remove the not "element_hacks" check when these are resolved: @@ -209,16 +216,18 @@ pub(crate) async fn get_message_events_route( * https://github.com/vector-im/element-web/issues/21034 */ if !cfg!(feature = "element_hacks") - && !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &body.room_id, - &event.sender, - )? { + && !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await + { lazy_loaded.insert(event.sender.clone()); } - lazy_loaded.insert(event.sender.clone()); + if cfg!(features = "element_hacks") { + lazy_loaded.insert(event.sender.clone()); + } } next_token = events_before.last().map(|(count, _)| count).copied(); @@ -236,11 +245,11 @@ pub(crate) async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { - if let Some(member_event) = - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? + if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, ll_id.as_str()) + .await { resp.state.push(member_event.to_state_event()); } @@ -249,34 +258,43 @@ pub(crate) async fn get_message_events_route( // remove the feature check when we are sure clients like element can handle it if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token) - .await; + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_token, + ); } } Ok(resp) } -fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { +async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option { + let (_, pdu) = &item; + services .rooms .state_accessor - .user_can_see_event(user_id, room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(user_id, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) } -fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool { +fn contains_url_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { + let (_, pdu) = &item; + if filter.url_filter.is_none() { - return true; + return Some(item); } let content: Value = from_str(pdu.content.get()).unwrap(); - match filter.url_filter { + let res = match filter.url_filter { Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(), Some(UrlFilter::EventsWithUrl) => content["url"].is_string(), None => true, - } + }; + + res.then_some(item) } diff --git a/src/api/client/presence.rs b/src/api/client/presence.rs index 8384d5aca..ba48808bd 100644 --- a/src/api/client/presence.rs +++ b/src/api/client/presence.rs @@ -28,7 +28,8 @@ pub(crate) async fn set_presence_route( services .presence - .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?; + .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone()) + .await?; Ok(set_presence::v3::Response {}) } @@ -49,14 +50,15 @@ pub(crate) async fn get_presence_route( let mut presence_event = None; - for _room_id in services + let has_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - { - if let Some(presence) = services.presence.get_presence(&body.user_id)? { + .has_shared_rooms(sender_user, &body.user_id) + .await; + + if has_shared_rooms { + if let Ok(presence) = services.presence.get_presence(&body.user_id).await { presence_event = Some(presence); - break; } } diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index bf47a3f85..495bc8ec3 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -1,5 +1,10 @@ use axum::extract::State; -use conduit::{pdu::PduBuilder, warn, Err, Error, Result}; +use conduit::{ + pdu::PduBuilder, + utils::{stream::TryIgnore, IterStream}, + warn, Err, Error, Result, +}; +use futures::{StreamExt, TryStreamExt}; use ruma::{ api::{ client::{ @@ -35,16 +40,18 @@ pub(crate) async fn set_displayname_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; - update_displayname(&services, &body.user_id, body.displayname.clone(), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, body.displayname.clone(), &all_joined_rooms).await?; if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_display_name::v3::Response {}) @@ -72,22 +79,19 @@ pub(crate) async fn get_displayname_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); return Ok(get_display_name::v3::Response { displayname: response.displayname, @@ -95,14 +99,14 @@ pub(crate) async fn get_displayname_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_display_name::v3::Response { - displayname: services.users.displayname(&body.user_id)?, + displayname: services.users.displayname(&body.user_id).await.ok(), }) } @@ -124,15 +128,16 @@ pub(crate) async fn set_avatar_url_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; update_avatar_url( &services, &body.user_id, body.avatar_url.clone(), body.blurhash.clone(), - all_joined_rooms, + &all_joined_rooms, ) .await?; @@ -140,7 +145,9 @@ pub(crate) async fn set_avatar_url_route( // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await + .ok(); } Ok(set_avatar_url::v3::Response {}) @@ -168,22 +175,21 @@ pub(crate) async fn get_avatar_url_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); return Ok(get_avatar_url::v3::Response { avatar_url: response.avatar_url, @@ -192,15 +198,15 @@ pub(crate) async fn get_avatar_url_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_avatar_url::v3::Response { - avatar_url: services.users.avatar_url(&body.user_id)?, - blurhash: services.users.blurhash(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), }) } @@ -226,31 +232,30 @@ pub(crate) async fn get_profile_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); for (profile_key, profile_key_value) in &response.custom_profile_fields { services .users - .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone())); } return Ok(get_profile::v3::Response { @@ -263,104 +268,93 @@ pub(crate) async fn get_profile_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_profile::v3::Response { - avatar_url: services.users.avatar_url(&body.user_id)?, - blurhash: services.users.blurhash(&body.user_id)?, - displayname: services.users.displayname(&body.user_id)?, - tz: services.users.timezone(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), + displayname: services.users.displayname(&body.user_id).await.ok(), + tz: services.users.timezone(&body.user_id).await.ok(), custom_profile_fields: services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(), + .collect() + .await, }) } pub async fn update_displayname( - services: &Services, user_id: &UserId, displayname: Option, all_joined_rooms: Vec, + services: &Services, user_id: &UserId, displayname: Option, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - let current_display_name = services.users.displayname(user_id).unwrap_or_default(); + let current_display_name = services.users.displayname(user_id).await.ok(); if displayname == current_display_name { return Ok(()); } - services - .users - .set_displayname(user_id, displayname.clone()) - .await?; + services.users.set_displayname(user_id, displayname.clone()); // Send a new join membership event into all joined rooms - let all_joined_rooms: Vec<_> = all_joined_rooms - .iter() - .map(|room_id| { - Ok::<_, Error>(( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - displayname: displayname.clone(), - join_authorized_via_users_server: None, - ..serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or_else(|| { - Error::bad_database("Tried to send display name update for user not in the room.") - })? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - room_id, - )) - }) - .filter_map(Result::ok) - .collect(); + let mut joined_rooms = Vec::new(); + for room_id in all_joined_rooms { + let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + else { + continue; + }; + + let pdu = PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + displayname: displayname.clone(), + join_authorized_via_users_server: None, + ..serde_json::from_str(event.content.get()).expect("Database contains invalid PDU.") + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }; + + joined_rooms.push((pdu, room_id)); + } - update_all_rooms(services, all_joined_rooms, user_id).await; + update_all_rooms(services, joined_rooms, user_id).await; Ok(()) } pub async fn update_avatar_url( services: &Services, user_id: &UserId, avatar_url: Option, blurhash: Option, - all_joined_rooms: Vec, + all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - let current_avatar_url = services.users.avatar_url(user_id).unwrap_or_default(); - let current_blurhash = services.users.blurhash(user_id).unwrap_or_default(); + let current_avatar_url = services.users.avatar_url(user_id).await.ok(); + let current_blurhash = services.users.blurhash(user_id).await.ok(); if current_avatar_url == avatar_url && current_blurhash == blurhash { return Ok(()); } - services - .users - .set_avatar_url(user_id, avatar_url.clone()) - .await?; - services - .users - .set_blurhash(user_id, blurhash.clone()) - .await?; + services.users.set_avatar_url(user_id, avatar_url.clone()); + + services.users.set_blurhash(user_id, blurhash.clone()); // Send a new join membership event into all joined rooms + let avatar_url = &avatar_url; + let blurhash = &blurhash; let all_joined_rooms: Vec<_> = all_joined_rooms .iter() - .map(|room_id| { - Ok::<_, Error>(( + .try_stream() + .and_then(|room_id: &OwnedRoomId| async move { + Ok(( PduBuilder { event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { @@ -371,8 +365,9 @@ pub async fn update_avatar_url( services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or_else(|| { + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_err(|_| { Error::bad_database("Tried to send avatar URL update for user not in the room.") })? .content @@ -389,8 +384,9 @@ pub async fn update_avatar_url( room_id, )) }) - .filter_map(Result::ok) - .collect(); + .ignore_err() + .collect() + .await; update_all_rooms(services, all_joined_rooms, user_id).await; diff --git a/src/api/client/push.rs b/src/api/client/push.rs index 8723e676b..390951999 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -29,40 +29,36 @@ pub(crate) async fn get_pushrules_all_route( let global_ruleset: Ruleset; - let Ok(event) = - services - .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) - else { - // push rules event doesn't exist, create it and return default - return recreate_push_rules_and_return(&services, sender_user); + let event = services + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await; + + let Ok(event) = event else { + // user somehow has non-existent push rule event. recreate it and return server + // default silently + return recreate_push_rules_and_return(&services, sender_user).await; }; - if let Some(event) = event { - let value = serde_json::from_str::(event.get()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + let value = serde_json::from_str::(event.get()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - let Some(content_value) = value.get("content") else { - // user somehow has a push rule event with no content key, recreate it and - // return server default silently - return recreate_push_rules_and_return(&services, sender_user); - }; + let Some(content_value) = value.get("content") else { + // user somehow has a push rule event with no content key, recreate it and + // return server default silently + return recreate_push_rules_and_return(&services, sender_user).await; + }; - if content_value.to_string().is_empty() { - // user somehow has a push rule event with empty content, recreate it and return - // server default silently - return recreate_push_rules_and_return(&services, sender_user); - } + if content_value.to_string().is_empty() { + // user somehow has a push rule event with empty content, recreate it and return + // server default silently + return recreate_push_rules_and_return(&services, sender_user).await; + } - let account_data_content = serde_json::from_value::(content_value.clone().into()) - .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; + let account_data_content = serde_json::from_value::(content_value.clone().into()) + .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; - global_ruleset = account_data_content.global; - } else { - // user somehow has non-existent push rule event. recreate it and return server - // default silently - return recreate_push_rules_and_return(&services, sender_user); - } + global_ruleset = account_data_content.global; Ok(get_pushrules_all::v3::Response { global: global_ruleset, @@ -79,8 +75,9 @@ pub(crate) async fn get_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -118,8 +115,9 @@ pub(crate) async fn set_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -155,12 +153,15 @@ pub(crate) async fn set_pushrule_route( return Err(err); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule::v3::Response {}) } @@ -182,8 +183,9 @@ pub(crate) async fn get_pushrule_actions_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -217,8 +219,9 @@ pub(crate) async fn set_pushrule_actions_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -232,12 +235,15 @@ pub(crate) async fn set_pushrule_actions_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_actions::v3::Response {}) } @@ -259,8 +265,9 @@ pub(crate) async fn get_pushrule_enabled_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -293,8 +300,9 @@ pub(crate) async fn set_pushrule_enabled_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -308,12 +316,15 @@ pub(crate) async fn set_pushrule_enabled_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(set_pushrule_enabled::v3::Response {}) } @@ -335,8 +346,9 @@ pub(crate) async fn delete_pushrule_route( let event = services .account_data - .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; @@ -357,12 +369,15 @@ pub(crate) async fn delete_pushrule_route( return Err(err); } - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; Ok(delete_pushrule::v3::Response {}) } @@ -376,7 +391,7 @@ pub(crate) async fn get_pushers_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: services.pusher.get_pushers(sender_user)?, + pushers: services.pusher.get_pushers(sender_user).await, }) } @@ -390,27 +405,30 @@ pub(crate) async fn set_pushers_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services.pusher.set_pusher(sender_user, &body.action)?; + services.pusher.set_pusher(sender_user, &body.action); Ok(set_pusher::v3::Response::default()) } /// user somehow has bad push rules, these must always exist per spec. /// so recreate it and return server default silently -fn recreate_push_rules_and_return( +async fn recreate_push_rules_and_return( services: &Services, sender_user: &ruma::UserId, ) -> Result { - services.account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(PushRulesEvent { - content: PushRulesEventContent { - global: Ruleset::server_default(sender_user), - }, - }) - .expect("to json always works"), - )?; + services + .account_data + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { + global: Ruleset::server_default(sender_user), + }, + }) + .expect("to json always works"), + ) + .await?; Ok(get_pushrules_all::v3::Response { global: Ruleset::server_default(sender_user), diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index f40f24932..f28b2aec5 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -31,27 +31,32 @@ pub(crate) async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { services .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id); } if let Some(event) = &body.private_read_receipt { let count = services .rooms .timeline - .get_pdu_count(event)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .get_pdu_count(event) + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -64,7 +69,7 @@ pub(crate) async fn set_read_marker_route( services .rooms .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count); } if let Some(event) = &body.read_receipt { @@ -83,14 +88,18 @@ pub(crate) async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - services.rooms.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - &ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services + .rooms + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + &ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await; } Ok(set_read_marker::v3::Response {}) @@ -111,7 +120,7 @@ pub(crate) async fn create_receipt_route( services .rooms .user - .reset_notification_counts(sender_user, &body.room_id)?; + .reset_notification_counts(sender_user, &body.room_id); } match body.receipt_type { @@ -121,12 +130,15 @@ pub(crate) async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + ) + .await?; }, create_receipt::v3::ReceiptType::Read => { let mut user_receipts = BTreeMap::new(); @@ -143,21 +155,27 @@ pub(crate) async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.clone(), receipts); - services.rooms.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - &ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; + services + .rooms + .read_receipt + .readreceipt_update( + sender_user, + &body.room_id, + &ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + ) + .await; }, create_receipt::v3::ReceiptType::ReadPrivate => { let count = services .rooms .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .get_pdu_count(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { PduCount::Backfilled(_) => { return Err(Error::BadRequest( @@ -170,7 +188,7 @@ pub(crate) async fn create_receipt_route( services .rooms .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; + .private_read_set(&body.room_id, sender_user, count); }, _ => return Err(Error::bad_database("Unsupported receipt type")), } diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index ae6459400..d43847300 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -9,20 +9,24 @@ use crate::{Result, Ruma}; pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let res = services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &Some(body.event_type.clone()), - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + body.event_type.clone().into(), + body.rel_type.clone().into(), + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await?; Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, @@ -36,20 +40,24 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let res = services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + body.rel_type.clone().into(), + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -63,18 +71,22 @@ pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - services.rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &None, - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - ) + services + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + None, + body.from.as_ref(), + body.to.as_ref(), + body.limit, + body.recurse, + body.dir, + ) + .await } diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 588bd3686..a40c35a28 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,6 +1,7 @@ use std::time::Duration; use axum::extract::State; +use conduit::{utils::ReadyExt, Err}; use rand::Rng; use ruma::{ api::client::{error::ErrorKind, room::report_content}, @@ -34,11 +35,8 @@ pub(crate) async fn report_event_route( delay_response().await; // check if we know about the reported event ID or if it's invalid - let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Event ID is not known to us or Event ID is invalid", - )); + let Ok(pdu) = services.rooms.timeline.get_pdu(&body.event_id).await else { + return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid"))); }; is_report_valid( @@ -49,7 +47,8 @@ pub(crate) async fn report_event_route( &body.reason, body.score, &pdu, - )?; + ) + .await?; // send admin room message that we received the report with an @room ping for // urgency @@ -81,7 +80,8 @@ pub(crate) async fn report_event_route( HtmlEscape(body.reason.as_deref().unwrap_or("")) ), )) - .await; + .await + .ok(); Ok(report_content::v3::Response {}) } @@ -92,7 +92,7 @@ pub(crate) async fn report_event_route( /// check if score is in valid range /// check if report reasoning is less than or equal to 750 characters /// check if reporting user is in the reporting room -fn is_report_valid( +async fn is_report_valid( services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option, score: Option, pdu: &std::sync::Arc, ) -> Result<()> { @@ -123,8 +123,8 @@ fn is_report_valid( .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .any(|user_id| user_id == *sender_user) + .ready_any(|user_id| user_id == sender_user) + .await { return Err(Error::BadRequest( ErrorKind::NotFound, diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 0112e76dc..1edf85d80 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -2,6 +2,7 @@ use std::{cmp::max, collections::BTreeMap}; use axum::extract::State; use conduit::{debug_info, debug_warn, err, Err}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -74,7 +75,7 @@ pub(crate) async fn create_room_route( if !services.globals.allow_room_creation() && body.appservice_info.is_none() - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); } @@ -86,7 +87,7 @@ pub(crate) async fn create_room_route( }; // check if room ID doesn't already exist instead of erroring on auth check - if services.rooms.short.get_shortroomid(&room_id)?.is_some() { + if services.rooms.short.get_shortroomid(&room_id).await.is_ok() { return Err(Error::BadRequest( ErrorKind::RoomInUse, "Room with that custom room ID already exists", @@ -95,7 +96,7 @@ pub(crate) async fn create_room_route( if body.visibility == room::Visibility::Public && services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -118,7 +119,11 @@ pub(crate) async fn create_room_route( return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed"))); } - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; let alias: Option = if let Some(alias) = &body.room_alias_name { @@ -218,6 +223,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 2. Let the room creator join @@ -229,11 +235,11 @@ pub(crate) async fn create_room_route( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: None, join_authorized_via_users_server: None, }) @@ -247,6 +253,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 3. Power levels @@ -284,6 +291,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 4. Canonical room alias @@ -308,6 +316,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } @@ -335,6 +344,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 5.2 History Visibility @@ -355,6 +365,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 5.3 Guest Access @@ -378,6 +389,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; // 6. Events listed in initial_state @@ -410,6 +422,7 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .boxed() .await?; } @@ -432,6 +445,7 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } @@ -455,13 +469,17 @@ pub(crate) async fn create_room_route( &room_id, &state_lock, ) + .boxed() .await?; } // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await { + if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct) + .boxed() + .await + { warn!(%e, "Failed to send invite"); } } @@ -475,7 +493,7 @@ pub(crate) async fn create_room_route( } if body.visibility == room::Visibility::Public { - services.rooms.directory.set_public(&room_id)?; + services.rooms.directory.set_public(&room_id); if services.globals.config.admin_room_notices { services @@ -505,13 +523,15 @@ pub(crate) async fn get_room_event_route( let event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or_else(|| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; + .get_pdu(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &event.room_id, &body.event_id)? + .user_can_see_event(sender_user, &event.room_id, &body.event_id) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -541,7 +561,8 @@ pub(crate) async fn get_room_aliases_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -554,8 +575,9 @@ pub(crate) async fn get_room_aliases_route( .rooms .alias .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -591,7 +613,8 @@ pub(crate) async fn upgrade_room_route( let _short_id = services .rooms .short - .get_or_create_shortroomid(&replacement_room)?; + .get_or_create_shortroomid(&replacement_room) + .await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -629,12 +652,12 @@ pub(crate) async fn upgrade_room_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .room_state_get(&body.room_id, &StateEventType::RoomCreate, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))? .content .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + )?; // Use the m.room.tombstone event as the predecessor let predecessor = Some(ruma::events::room::create::PreviousRoom::new( @@ -714,11 +737,11 @@ pub(crate) async fn upgrade_room_route( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), is_direct: None, third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user).await.ok(), reason: None, join_authorized_via_users_server: None, }) @@ -739,10 +762,11 @@ pub(crate) async fn upgrade_room_route( let event_content = match services .rooms .state_accessor - .room_state_get(&body.room_id, event_type, "")? + .room_state_get(&body.room_id, event_type, "") + .await { - Some(v) => v.content.clone(), - None => continue, // Skipping missing events. + Ok(v) => v.content.clone(), + Err(_) => continue, // Skipping missing events. }; services @@ -765,21 +789,23 @@ pub(crate) async fn upgrade_room_route( } // Moves any local aliases to the new room - for alias in services + let mut local_aliases = services .rooms .alias .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - { + .boxed(); + + while let Some(alias) = local_aliases.next().await { services .rooms .alias - .remove_alias(&alias, sender_user) + .remove_alias(alias, sender_user) .await?; + services .rooms .alias - .set_alias(&alias, &replacement_room, sender_user)?; + .set_alias(alias, &replacement_room, sender_user)?; } // Get the old room power levels @@ -787,12 +813,12 @@ pub(crate) async fn upgrade_room_route( services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))? .content .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + )?; // Setting events_default and invite to the greater of 50 and users_default + 1 let new_level = max( @@ -800,9 +826,7 @@ pub(crate) async fn upgrade_room_route( power_levels_event_content .users_default .checked_add(int!(1)) - .ok_or_else(|| { - Error::BadRequest(ErrorKind::BadJson, "users_default power levels event content is not valid") - })?, + .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, ); power_levels_event_content.events_default = new_level; power_levels_event_content.invite = new_level; @@ -921,8 +945,9 @@ async fn room_alias_check( if services .rooms .alias - .resolve_local_alias(&full_room_alias)? - .is_some() + .resolve_local_alias(&full_room_alias) + .await + .is_ok() { return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); } diff --git a/src/api/client/search.rs b/src/api/client/search.rs index b143bd2c7..b073640e8 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -1,6 +1,12 @@ use std::collections::BTreeMap; use axum::extract::State; +use conduit::{ + debug, + utils::{IterStream, ReadyExt}, + Err, +}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -13,7 +19,6 @@ use ruma::{ serde::Raw, uint, OwnedRoomId, }; -use tracing::debug; use crate::{Error, Result, Ruma}; @@ -32,14 +37,17 @@ pub(crate) async fn search_events_route( let filter = &search_criteria.filter; let include_state = &search_criteria.include_state; - let room_ids = filter.rooms.clone().unwrap_or_else(|| { + let room_ids = if let Some(room_ids) = &filter.rooms { + room_ids.clone() + } else { services .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) + .map(ToOwned::to_owned) .collect() - }); + .await + }; // Use limit or else 10, with maximum 100 let limit: usize = filter @@ -53,27 +61,30 @@ pub(crate) async fn search_events_route( if include_state.is_some_and(|include_state| include_state) { for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { + return Err!(Request(Forbidden("You don't have permission to view this room."))); } // check if sender_user can see state events if services .rooms .state_accessor - .user_can_see_state_events(sender_user, room_id)? + .user_can_see_state_events(sender_user, room_id) + .await { - let room_state = services + let room_state: Vec<_> = services .rooms .state_accessor .room_state_full(room_id) .await? .values() .map(|pdu| pdu.to_state_event()) - .collect::>(); + .collect(); debug!("Room state: {:?}", room_state); @@ -87,10 +98,15 @@ pub(crate) async fn search_events_route( } } - let mut searches = Vec::new(); + let mut search_vecs = Vec::new(); for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -100,12 +116,18 @@ pub(crate) async fn search_events_route( if let Some(search) = services .rooms .search - .search_pdus(room_id, &search_criteria.search_term)? + .search_pdus(room_id, &search_criteria.search_term) + .await { - searches.push(search.0.peekable()); + search_vecs.push(search.0); } } + let mut searches: Vec<_> = search_vecs + .iter() + .map(|vec| vec.iter().peekable()) + .collect(); + let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { Some(Ok(s)) => s, Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), @@ -118,8 +140,8 @@ pub(crate) async fn search_events_route( for _ in 0..next_batch { if let Some(s) = searches .iter_mut() - .map(|s| (s.peek().cloned(), s)) - .max_by_key(|(peek, _)| peek.clone()) + .map(|s| (s.peek().copied(), s)) + .max_by_key(|(peek, _)| *peek) .and_then(|(_, i)| i.next()) { results.push(s); @@ -127,42 +149,38 @@ pub(crate) async fn search_events_route( } let results: Vec<_> = results - .iter() + .into_iter() .skip(skip) - .filter_map(|result| { + .stream() + .filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok)) + .ready_filter(|pdu| !pdu.is_redacted()) + .filter_map(|pdu| async move { services .rooms - .timeline - .get_pdu_from_id(result) - .ok()? - .filter(|pdu| { - !pdu.is_redacted() - && services - .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .map(|pdu| pdu.to_room_event()) - }) - .map(|result| { - Ok::<_, Error>(SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, - rank: None, - result: Some(result), - }) + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) }) - .filter_map(Result::ok) .take(limit) - .collect(); + .map(|pdu| pdu.to_room_event()) + .map(|result| SearchResult { + context: EventContextResult { + end: None, + events_after: Vec::new(), + events_before: Vec::new(), + profile_info: BTreeMap::new(), + start: None, + }, + rank: None, + result: Some(result), + }) + .collect() + .boxed() + .await; let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); + let next_batch = more_unloaded_results.then(|| next_batch.to_string()); Ok(search_events::v3::Response::new(ResultCategories { diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 4702b0ec1..6347a2c95 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -1,5 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; +use conduit::{debug, err, info, utils::ReadyExt, warn, Err}; +use futures::StreamExt; use ruma::{ api::client::{ error::ErrorKind, @@ -19,7 +21,6 @@ use ruma::{ UserId, }; use serde::Deserialize; -use tracing::{debug, info, warn}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{utils, utils::hash, Error, Result, Ruma}; @@ -79,21 +80,22 @@ pub(crate) async fn login_route( UserId::parse(user) } else { warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::forbidden(), "Bad login type.")); + return Err!(Request(Forbidden("Bad login type."))); } .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; let hash = services .users - .password_hash(&user_id)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; + .password_hash(&user_id) + .await + .map_err(|_| err!(Request(Forbidden("Wrong username or password."))))?; if hash.is_empty() { - return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated")); + return Err!(Request(UserDeactivated("The user has been deactivated"))); } if hash::verify_password(password, &hash).is_err() { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password.")); + return Err!(Request(Forbidden("Wrong username or password."))); } user_id @@ -112,15 +114,12 @@ pub(crate) async fn login_route( let username = token.claims.sub.to_lowercase(); - UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| { - warn!("Failed to parse username from user logging in: {e}"); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })? + UserId::parse_with_server_name(username, services.globals.server_name()) + .map_err(|e| err!(Request(InvalidUsername(debug_error!(?e, "Failed to parse login username")))))? } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); + return Err!(Request(Unknown( + "Token login is not supported (server has no jwt decoding key)." + ))); } }, #[allow(deprecated)] @@ -169,23 +168,32 @@ pub(crate) async fn login_route( let token = utils::random_string(TOKEN_LENGTH); // Determine if device_id was provided and exists in the db for this user - let device_exists = body.device_id.as_ref().map_or(false, |device_id| { + let device_exists = if body.device_id.is_some() { services .users .all_device_ids(&user_id) - .any(|x| x.as_ref().map_or(false, |v| v == device_id)) - }); + .ready_any(|v| v == device_id) + .await + } else { + false + }; if device_exists { - services.users.set_token(&user_id, &device_id, &token)?; + services + .users + .set_token(&user_id, &device_id, &token) + .await?; } else { - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; } // send client well-known if specified so the client knows to reconfigure itself @@ -228,10 +236,13 @@ pub(crate) async fn logout_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - services.users.remove_device(sender_user, sender_device)?; + services + .users + .remove_device(sender_user, sender_device) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout::v3::Response::new()) } @@ -256,12 +267,14 @@ pub(crate) async fn logout_all_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in services.users.all_device_ids(sender_user).flatten() { - services.users.remove_device(sender_user, &device_id)?; - } + services + .users + .all_device_ids(sender_user) + .for_each(|device_id| services.users.remove_device(sender_user, device_id)) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index fd0496639..f9a4a7636 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{debug_info, error, pdu::PduBuilder, Error, Result}; +use conduit::{err, error, pdu::PduBuilder, Err, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, @@ -84,12 +84,10 @@ pub(crate) async fn get_state_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } Ok(get_state_events::v3::Response { @@ -120,22 +118,25 @@ pub(crate) async fn get_state_events_for_key_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &body.event_type, &body.state_key)? - .ok_or_else(|| { - debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") + .room_state_get(&body.room_id, &body.event_type, &body.state_key) + .await + .map_err(|_| { + err!(Request(NotFound(error!( + room_id = ?body.room_id, + event_type = ?body.event_type, + "State event not found in room.", + )))) })?; + if body .format .as_ref() @@ -204,7 +205,7 @@ async fn send_state_event_for_key_helper( async fn allowed_to_send_state_event( services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw, -) -> Result<()> { +) -> Result { match event_type { // Forbid m.room.encryption if encryption is disabled StateEventType::RoomEncryption => { @@ -214,7 +215,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made public StateEventType::RoomJoinRules => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(join_rule) = serde_json::from_str::(json.json().get()) { if join_rule.join_rule == JoinRule::Public { @@ -229,7 +230,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made world readable StateEventType::RoomHistoryVisibility => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(visibility_content) = serde_json::from_str::(json.json().get()) @@ -254,23 +255,27 @@ async fn allowed_to_send_state_event( } for alias in aliases { - if !services.globals.server_is_ours(alias.server_name()) - || services - .rooms - .alias - .resolve_local_alias(&alias)? - .filter(|room| room == room_id) // Make sure it's the right room - .is_none() + if !services.globals.server_is_ours(alias.server_name()) { + return Err!(Request(Forbidden("canonical_alias must be for this server"))); + } + + if !services + .rooms + .alias + .resolve_local_alias(&alias) + .await + .is_ok_and(|room| room == room_id) + // Make sure it's the right room { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You are only allowed to send canonical_alias events when its aliases already exist", - )); + return Err!(Request(Forbidden( + "You are only allowed to send canonical_alias events when its aliases already exist" + ))); } } } }, _ => (), } + Ok(()) } diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index eb534205e..adb4d8da7 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -6,10 +6,15 @@ use std::{ use axum::extract::State; use conduit::{ - debug, error, - utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - warn, Err, PduCount, + debug, err, error, is_equal_to, + result::IntoIsOk, + utils::{ + math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, + BoolExt, IterStream, ReadyExt, TryFutureExtExt, + }, + warn, PduCount, }; +use futures::{pin_mut, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::client::{ error::ErrorKind, @@ -108,7 +113,8 @@ pub(crate) async fn sync_events_route( if services.globals.allow_local_presence() { services .presence - .ping_presence(&sender_user, &body.set_presence)?; + .ping_presence(&sender_user, &body.set_presence) + .await?; } // Setup watchers, so if there's no response, we can wait for them @@ -124,7 +130,8 @@ pub(crate) async fn sync_events_route( Some(Filter::FilterDefinition(filter)) => filter, Some(Filter::FilterId(filter_id)) => services .users - .get_filter(&sender_user, &filter_id)? + .get_filter(&sender_user, &filter_id) + .await .unwrap_or_default(), }; @@ -157,24 +164,27 @@ pub(crate) async fn sync_events_route( services .users .keys_changed(sender_user.as_ref(), since, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); if services.globals.allow_local_presence() { process_presence_updates(&services, &mut presence_updates, since, &sender_user).await?; } - let all_joined_rooms = services + let all_joined_rooms: Vec<_> = services .rooms .state_cache .rooms_joined(&sender_user) - .collect::>(); + .map(ToOwned::to_owned) + .collect() + .await; // Coalesce database writes for the remainder of this scope. let _cork = services.db.cork_and_flush(); for room_id in all_joined_rooms { - let room_id = room_id?; if let Ok(joined_room) = load_joined_room( &services, &sender_user, @@ -203,12 +213,14 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_left(&sender_user) - .collect(); + .collect() + .await; + for result in all_left_rooms { handle_left_room( &services, since, - &result?.0, + &result.0, &sender_user, &mut left_rooms, &next_batch_string, @@ -224,10 +236,10 @@ pub(crate) async fn sync_events_route( .rooms .state_cache .rooms_invited(&sender_user) - .collect(); - for result in all_invited_rooms { - let (room_id, invite_state_events) = result?; + .collect() + .await; + for (room_id, invite_state_events) in all_invited_rooms { // Get and drop the lock to wait for remaining operations to finish let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; drop(insert_lock); @@ -235,7 +247,9 @@ pub(crate) async fn sync_events_route( let invite_count = services .rooms .state_cache - .get_invite_count(&room_id, &sender_user)?; + .get_invite_count(&room_id, &sender_user) + .await + .ok(); // Invited before last sync if Some(since) >= invite_count { @@ -253,22 +267,8 @@ pub(crate) async fn sync_events_route( } for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + // If the user doesn't share an encrypted room with the target anymore, we need // to tell them if dont_share_encrypted_room { @@ -279,7 +279,8 @@ pub(crate) async fn sync_events_route( // Remove all to-device events the device received *last time* services .users - .remove_to_device_events(&sender_user, &sender_device, since)?; + .remove_to_device_events(&sender_user, &sender_device, since) + .await; let response = sync_events::v3::Response { next_batch: next_batch_string, @@ -298,7 +299,8 @@ pub(crate) async fn sync_events_route( account_data: GlobalAccountData { events: services .account_data - .changes_since(None, &sender_user, since)? + .changes_since(None, &sender_user, since) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(), @@ -309,11 +311,14 @@ pub(crate) async fn sync_events_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device)?, + .count_one_time_keys(&sender_user, &sender_device) + .await, to_device: ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device)?, + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, }, // Fallback keys are not yet supported device_unused_fallback_key_types: None, @@ -351,14 +356,16 @@ async fn handle_left_room( let left_count = services .rooms .state_cache - .get_left_count(room_id, sender_user)?; + .get_left_count(room_id, sender_user) + .await + .ok(); // Left before last sync if Some(since) >= left_count { return Ok(()); } - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { // This is just a rejected invite, not a room we know // Insert a leave event anyways let event = PduEvent { @@ -408,27 +415,29 @@ async fn handle_left_room( let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, since)?; + .get_token_shortstatehash(room_id, since) + .await; let since_state_ids = match since_shortstatehash { - Some(s) => services.rooms.state_accessor.state_full_ids(s).await?, - None => HashMap::new(), + Ok(s) => services.rooms.state_accessor.state_full_ids(s).await?, + Err(_) => HashMap::new(), }; - let Some(left_event_id) = - services - .rooms - .state_accessor - .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())? + let Ok(left_event_id) = services + .rooms + .state_accessor + .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str()) + .await else { error!("Left room but no left state event"); return Ok(()); }; - let Some(left_shortstatehash) = services + let Ok(left_shortstatehash) = services .rooms .state_accessor - .pdu_shortstatehash(&left_event_id)? + .pdu_shortstatehash(&left_event_id) + .await else { error!(event_id = %left_event_id, "Leave event has no state"); return Ok(()); @@ -443,14 +452,15 @@ async fn handle_left_room( let leave_shortstatekey = services .rooms .short - .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str()) + .await; left_state_ids.insert(leave_shortstatekey, left_event_id); let mut i: u8 = 0; for (key, id) in left_state_ids { if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key)?; + let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key).await?; if !lazy_load_enabled || event_type != StateEventType::RoomMember @@ -458,7 +468,7 @@ async fn handle_left_room( // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { error!("Pdu in state not found: {}", id); continue; }; @@ -495,19 +505,25 @@ async fn handle_left_room( async fn process_presence_updates( services: &Services, presence_updates: &mut HashMap, since: u64, syncing_user: &UserId, ) -> Result<()> { + let presence_since = services.presence.presence_since(since); + // Take presence updates - for (user_id, _, presence_bytes) in services.presence.presence_since(since) { + pin_mut!(presence_since); + while let Some((user_id, _, presence_bytes)) = presence_since.next().await { if !services .rooms .state_cache - .user_sees_user(syncing_user, &user_id)? + .user_sees_user(syncing_user, &user_id) + .await { continue; } let presence_event = services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; + .from_json_bytes_to_event(&presence_bytes, &user_id) + .await?; + match presence_updates.entry(user_id) { Entry::Vacant(slot) => { slot.insert(presence_event); @@ -551,14 +567,14 @@ async fn load_joined_room( let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); - let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10)?; + let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10).await?; let send_notification_counts = !timeline_pdus.is_empty() || services .rooms .user - .last_notification_read(sender_user, room_id)? - > since; + .last_notification_read(sender_user, room_id) + .await > since; let mut timeline_users = HashSet::new(); for (_, event) in &timeline_pdus { @@ -568,355 +584,382 @@ async fn load_joined_room( services .rooms .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) - .await?; + .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount); // Database queries: - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { - return Err!(Database(error!("Room {room_id} has no state"))); - }; + let current_shortstatehash = services + .rooms + .state + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Database(error!("Room {room_id} has no state"))))?; let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, since)?; + .get_token_shortstatehash(room_id, since) + .await + .ok(); - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = - if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0); - let invited_member_count = services - .rooms - .state_cache - .room_invited_count(room_id)? - .unwrap_or(0); + let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus + .is_empty() + && (since_shortstatehash.is_none() || since_shortstatehash.is_some_and(is_equal_to!(current_shortstatehash))) + { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || async { + let joined_member_count = services + .rooms + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(0); - // Recalculate heroes (first 5 members) - let mut heroes: Vec = Vec::with_capacity(5); + let invited_member_count = services + .rooms + .state_cache + .room_invited_count(room_id) + .await + .unwrap_or(0); - if joined_member_count.saturating_add(invited_member_count) <= 5 { - // Go through all PDUs and for each member event, check if the user is still - // joined or invited until we have 5 or we reach the end + if joined_member_count.saturating_add(invited_member_count) > 5 { + return Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), Vec::new())); + } - for hero in services - .rooms - .timeline - .all_pdus(sender_user, room_id)? - .filter_map(Result::ok) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) - .map(|(_, pdu)| { - let content: RoomMemberEventContent = serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - // The membership was and still is invite or join - if matches!(content.membership, MembershipState::Join | MembershipState::Invite) - && (services.rooms.state_cache.is_joined(&user_id, room_id)? - || services.rooms.state_cache.is_invited(&user_id, room_id)?) - { - Ok::<_, Error>(Some(user_id)) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - .filter_map(Result::ok) - // Filter for possible heroes - .flatten() - { - if heroes.contains(&hero) || hero == sender_user { - continue; - } + // Go through all PDUs and for each member event, check if the user is still + // joined or invited until we have 5 or we reach the end - heroes.push(hero); + // Recalculate heroes (first 5 members) + let heroes = services + .rooms + .timeline + .all_pdus(sender_user, room_id) + .await? + .ready_filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) + .filter_map(|(_, pdu)| async move { + let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + return None; + }; + + let Some(state_key) = &pdu.state_key else { + return None; + }; + + let Ok(user_id) = UserId::parse(state_key) else { + return None; + }; + + if user_id == sender_user { + return None; } - } - Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) - }; + // The membership was and still is invite or join + if !matches!(content.membership, MembershipState::Join | MembershipState::Invite) { + return None; + } - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services + if !services .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() + .state_cache + .is_joined(&user_id, room_id) + .await && services + .rooms + .state_cache + .is_invited(&user_id, room_id) + .await + { + return None; + } + + Some(user_id) }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + .collect::>() + .await; + + Ok::<_, Error>(( + Some(joined_member_count), + Some(invited_member_count), + heroes.into_iter().collect::>(), + )) + }; - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + let since_sender_member: Option = if let Some(short) = since_shortstatehash { + services + .rooms + .state_accessor + .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) + .await + .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .ok() + } else { + None + }; - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + if since_shortstatehash.is_none() || joined_since_last_sync { + // Probably since = 0, we will do an initial sync - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; + let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?; - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; - let mut i: u8 = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services - .rooms - .short - .get_statekey_from_short(shortstatekey)?; + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); - if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - state_events.push(pdu); + let mut i: u8 = 0; + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = services + .rooms + .short + .get_statekey_from_short(shortstatekey) + .await?; - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled + if event_type != StateEventType::RoomMember { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + state_events.push(pdu); + + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } else if !lazy_load_enabled || full_state || timeline_users.contains(&state_key) // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) - { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; + { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); - } - state_events.push(pdu); + // This check is in case a bad user ID made it into the database + if let Ok(uid) = UserId::parse(&state_key) { + lazy_loaded.insert(uid); + } + state_events.push(pdu); - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; } } + } - // Reset lazy loading because this is an initial sync - services - .rooms - .lazy_loading - .lazy_load_reset(sender_user, sender_device, room_id)?; + // Reset lazy loading because this is an initial sync + services + .rooms + .lazy_loading + .lazy_load_reset(sender_user, sender_device, room_id) + .await; + + // The state_events above should contain all timeline_users, let's mark them as + // lazy loaded. + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - // The state_events above should contain all timeline_users, let's mark them as - // lazy loaded. - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; + (heroes, joined_member_count, invited_member_count, true, state_events) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.expect("missing since_shortstatehash on incremental sync"); - (heroes, joined_member_count, invited_member_count, true, state_events) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); + let mut delta_state_events = Vec::new(); - let mut delta_state_events = Vec::new(); + if since_shortstatehash != current_shortstatehash { + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; + let since_state_ids = services + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; - delta_state_events.push(pdu); - tokio::task::yield_now().await; - } + delta_state_events.push(pdu); + tokio::task::yield_now().await; } } + } - let encrypted_room = services - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + let encrypted_room = services + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_err(); - let send_member_count = delta_state_events - .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); + let send_member_count = delta_state_events + .iter() + .any(|event| event.kind == TimelineEventType::RoomMember); - if encrypted_room { - for state_event in &delta_state_events { - if state_event.kind != TimelineEventType::RoomMember { - continue; - } + if encrypted_room { + for state_event in &delta_state_events { + if state_event.kind != TimelineEventType::RoomMember { + continue; + } - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - if user_id == sender_user { - continue; - } + if user_id == sender_user { + continue; + } - let new_membership = - serde_json::from_str::(state_event.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; + let new_membership = serde_json::from_str::(state_event.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .membership; - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(services, sender_user, &user_id, room_id)? { - device_list_updates.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, - } + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(services, sender_user, &user_id, Some(room_id)).await { + device_list_updates.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, } } } + } - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(services, sender_user, user_id, room_id).unwrap_or(false) - }), - ); - } + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + services + .rooms + .state_cache + .room_members(room_id) + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != *user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) + }) + .collect::>() + .await, + ); + } - let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts()? - } else { - (None, None, Vec::new()) - }; + let (joined_member_count, invited_member_count, heroes) = if send_member_count { + calculate_counts().await? + } else { + (None, None, Vec::new()) + }; - let mut state_events = delta_state_events; - let mut lazy_loaded = HashSet::new(); - - // Mark all member events we're returning as lazy-loaded - for pdu in &state_events { - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse( - pdu.state_key - .as_ref() - .expect("State event has state key") - .clone(), - ) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - }, - Err(e) => error!("Invalid state key for member event: {}", e), - } + let mut state_events = delta_state_events; + let mut lazy_loaded = HashSet::new(); + + // Mark all member events we're returning as lazy-loaded + for pdu in &state_events { + if pdu.kind == TimelineEventType::RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + }, + Err(e) => error!("Invalid state key for member event: {}", e), } } + } - // Fetch contextual member state events for events from the timeline, and - // mark them as lazy-loaded as well. - for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { - continue; - } + // Fetch contextual member state events for events from the timeline, and + // mark them as lazy-loaded as well. + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant + { + if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, event.sender.as_str()) + .await { - if let Some(member_event) = services.rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); - } + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); } } + } - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) - } - }; + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) + } + }; // Look for device list updates in this room device_list_updates.extend( services .users .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); let notification_count = if send_notification_counts { @@ -924,7 +967,8 @@ async fn load_joined_room( services .rooms .user - .notification_count(sender_user, room_id)? + .notification_count(sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ) @@ -937,7 +981,8 @@ async fn load_joined_room( services .rooms .user - .highlight_count(sender_user, room_id)? + .highlight_count(sender_user, room_id) + .await .try_into() .expect("highlight count can't go that high"), ) @@ -966,9 +1011,9 @@ async fn load_joined_room( .rooms .read_receipt .readreceipts_since(room_id, since) - .filter_map(Result::ok) // Filter out buggy events .map(|(_, _, v)| v) - .collect(); + .collect() + .await; if services.rooms.typing.last_typing_update(room_id).await? > since { edus.push( @@ -985,13 +1030,15 @@ async fn load_joined_room( services .rooms .user - .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; + .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash) + .await; Ok(JoinedRoom { account_data: RoomAccountData { events: services .account_data - .changes_since(Some(room_id), sender_user, since)? + .changes_since(Some(room_id), sender_user, since) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), @@ -1023,41 +1070,37 @@ async fn load_joined_room( }) } -fn load_timeline( +async fn load_timeline( services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; let limited = if services .rooms .timeline - .last_timeline_count(sender_user, room_id)? + .last_timeline_count(sender_user, room_id) + .await? > roomsincecount { let mut non_timeline_pdus = services .rooms .timeline - .pdus_until(sender_user, room_id, PduCount::max())? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .take_while(|(pducount, _)| pducount > &roomsincecount); + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .ready_take_while(|(pducount, _)| pducount > &roomsincecount); // Take the last events for the timeline timeline_pdus = non_timeline_pdus .by_ref() .take(usize_from_u64_truncated(limit)) .collect::>() + .await .into_iter() .rev() .collect::>(); // They /sync response doesn't always return all messages, so we say the output // is limited unless there are events in non_timeline_pdus - non_timeline_pdus.next().is_some() + non_timeline_pdus.next().await.is_some() } else { timeline_pdus = Vec::new(); false @@ -1065,26 +1108,22 @@ fn load_timeline( Ok((timeline_pdus, limited)) } -fn share_encrypted_room( - services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, -) -> Result { - Ok(services +async fn share_encrypted_room( + services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: Option<&RoomId>, +) -> bool { + services .rooms .user - .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? - .filter_map(Result::ok) - .filter(|room_id| room_id != ignore_room) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) + .get_shared_rooms(sender_user, user_id) + .ready_filter(|&room_id| Some(room_id) != ignore_room) + .any(|other_room_id| { + services + .rooms + .state_accessor + .room_state_get(other_room_id, &StateEventType::RoomEncryption, "") + .map(Result::into_is_ok) }) - .any(|encrypted| encrypted)) + .await } /// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` @@ -1114,7 +1153,7 @@ pub(crate) async fn sync_events_v4_route( if globalsince != 0 && !services - .users + .sync .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) { debug!("Restarting sync stream because it was gone from the database"); @@ -1127,41 +1166,43 @@ pub(crate) async fn sync_events_v4_route( if globalsince == 0 { services - .users + .sync .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); } // Get sticky parameters from cache let known_rooms = services - .users + .sync .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); - let all_joined_rooms = services + let all_joined_rooms: Vec<_> = services .rooms .state_cache .rooms_joined(&sender_user) - .filter_map(Result::ok) - .collect::>(); + .map(ToOwned::to_owned) + .collect() + .await; - let all_invited_rooms = services + let all_invited_rooms: Vec<_> = services .rooms .state_cache .rooms_invited(&sender_user) - .filter_map(Result::ok) .map(|r| r.0) - .collect::>(); + .collect() + .await; let all_rooms = all_joined_rooms .iter() - .cloned() - .chain(all_invited_rooms.clone()) + .chain(all_invited_rooms.iter()) + .map(Clone::clone) .collect(); if body.extensions.to_device.enabled.unwrap_or(false) { services .users - .remove_to_device_events(&sender_user, &sender_device, globalsince)?; + .remove_to_device_events(&sender_user, &sender_device, globalsince) + .await; } let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in @@ -1179,7 +1220,8 @@ pub(crate) async fn sync_events_v4_route( if body.extensions.account_data.enabled.unwrap_or(false) { account_data.global = services .account_data - .changes_since(None, &sender_user, globalsince)? + .changes_since(None, &sender_user, globalsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(); @@ -1190,7 +1232,8 @@ pub(crate) async fn sync_events_v4_route( room.clone(), services .account_data - .changes_since(Some(&room), &sender_user, globalsince)? + .changes_since(Some(&room), &sender_user, globalsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), @@ -1205,40 +1248,42 @@ pub(crate) async fn sync_events_v4_route( services .users .keys_changed(sender_user.as_ref(), globalsince, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); for room_id in &all_joined_rooms { - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { - error!("Room {} has no state", room_id); + let Ok(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id).await else { + error!("Room {room_id} has no state"); continue; }; let since_shortstatehash = services .rooms .user - .get_token_shortstatehash(room_id, globalsince)?; + .get_token_shortstatehash(room_id, globalsince) + .await + .ok(); - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services - .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + let since_sender_member: Option = if let Some(short) = since_shortstatehash { + services + .rooms + .state_accessor + .state_get(short, &StateEventType::RoomMember, sender_user.as_str()) + .await + .and_then(|pdu| serde_json::from_str(pdu.content.get()).map_err(Into::into)) + .ok() + } else { + None + }; let encrypted_room = services .rooms .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); if let Some(since_shortstatehash) = since_shortstatehash { // Skip if there are only timeline changes @@ -1246,22 +1291,24 @@ pub(crate) async fn sync_events_v4_route( continue; } - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; let joined_since_last_sync = since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + let new_encrypted_room = encrypted_room && since_encryption.is_err(); + if encrypted_room { let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; + let since_state_ids = services .rooms .state_accessor @@ -1270,8 +1317,8 @@ pub(crate) async fn sync_events_v4_route( for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); continue; }; if pdu.kind == TimelineEventType::RoomMember { @@ -1291,7 +1338,9 @@ pub(crate) async fn sync_events_v4_route( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&services, &sender_user, &user_id, room_id)? { + if !share_encrypted_room(&services, &sender_user, &user_id, Some(room_id)) + .await + { device_list_changes.insert(user_id); } }, @@ -1306,22 +1355,23 @@ pub(crate) async fn sync_events_v4_route( } } if joined_since_last_sync || new_encrypted_room { + let sender_user = &sender_user; // If the user is in a new encrypted room, give them all joined users device_list_changes.extend( services .rooms .state_cache .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - &sender_user != user_id + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(&services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(&services, &sender_user, user_id, room_id).unwrap_or(false) - }), + .collect::>() + .await, ); } } @@ -1331,26 +1381,15 @@ pub(crate) async fn sync_events_v4_route( services .users .keys_changed(room_id.as_ref(), globalsince, None) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::>() + .await, ); } + for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + // If the user doesn't share an encrypted room with the target anymore, we need // to tell them if dont_share_encrypted_room { @@ -1362,7 +1401,7 @@ pub(crate) async fn sync_events_v4_route( let mut lists = BTreeMap::new(); let mut todo_rooms = BTreeMap::new(); // and required state - for (list_id, list) in body.lists { + for (list_id, list) in &body.lists { let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { Some(true) => &all_invited_rooms, Some(false) => &all_joined_rooms, @@ -1371,23 +1410,23 @@ pub(crate) async fn sync_events_v4_route( let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(active_rooms, State(services), &value, true), + Some(value) => filter_rooms(active_rooms, State(services), &value, true).await, None => active_rooms.clone(), }; let active_rooms = match list.filters.clone().map(|f| f.room_types) { Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(&active_rooms, State(services), &value, false), + Some(value) => filter_rooms(&active_rooms, State(services), &value, false).await, None => active_rooms, }; let mut new_known_rooms = BTreeSet::new(); + let ranges = list.ranges.clone(); lists.insert( list_id.clone(), sync_events::v4::SyncList { - ops: list - .ranges + ops: ranges .into_iter() .map(|mut r| { r.0 = r.0.clamp( @@ -1396,29 +1435,34 @@ pub(crate) async fn sync_events_v4_route( ); r.1 = r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); + let room_ids = if !active_rooms.is_empty() { active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() } else { Vec::new() }; + new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { let todo_room = todo_rooms .entry(room_id.clone()) .or_insert((BTreeSet::new(), 0, u64::MAX)); + let limit = list .room_details .timeline_limit .map_or(10, u64::from) .min(100); + todo_room .0 .extend(list.room_details.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date todo_room.2 = todo_room.2.min( known_rooms - .get(&list_id) + .get(list_id.as_str()) .and_then(|k| k.get(room_id)) .copied() .unwrap_or(0), @@ -1438,11 +1482,11 @@ pub(crate) async fn sync_events_v4_route( ); if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( + services.sync.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), - list_id, + list_id.clone(), new_known_rooms, globalsince, ); @@ -1451,7 +1495,7 @@ pub(crate) async fn sync_events_v4_route( let mut known_subscription_rooms = BTreeSet::new(); for (room_id, room) in &body.room_subscriptions { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { continue; } let todo_room = todo_rooms @@ -1477,7 +1521,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( + services.sync.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1488,7 +1532,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services.users.update_sync_subscriptions( + services.sync.update_sync_subscriptions( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1509,12 +1553,13 @@ pub(crate) async fn sync_events_v4_route( .rooms .state_cache .invite_state(&sender_user, room_id) - .unwrap_or(None); + .await + .ok(); (timeline_pdus, limited) = (Vec::new(), true); } else { (timeline_pdus, limited) = - match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit) { + match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit).await { Ok(value) => value, Err(err) => { warn!("Encountered missing timeline in {}, error {}", room_id, err); @@ -1527,17 +1572,20 @@ pub(crate) async fn sync_events_v4_route( room_id.clone(), services .account_data - .changes_since(Some(room_id), &sender_user, *roomsince)? + .changes_since(Some(room_id), &sender_user, *roomsince) + .await? .into_iter() .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect(), ); - let room_receipts = services + let vector: Vec<_> = services .rooms .read_receipt - .readreceipts_since(room_id, *roomsince); - let vector: Vec<_> = room_receipts.into_iter().collect(); + .readreceipts_since(room_id, *roomsince) + .collect() + .await; + let receipt_size = vector.len(); receipts .rooms @@ -1584,41 +1632,41 @@ pub(crate) async fn sync_events_v4_route( let required_state = required_state_request .iter() - .map(|state| { + .stream() + .filter_map(|state| async move { services .rooms .state_accessor .room_state_get(room_id, &state.0, &state.1) + .await + .map(|s| s.to_sync_state_event()) + .ok() }) - .filter_map(Result::ok) - .flatten() - .map(|state| state.to_sync_state_event()) - .collect(); + .collect() + .await; // Heroes - let heroes = services + let heroes: Vec<_> = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member != &sender_user) - .map(|member| { - Ok::<_, Error>( - services - .rooms - .state_accessor - .get_member(room_id, &member)? - .map(|memberevent| SlidingSyncRoomHero { - user_id: member, - name: memberevent.displayname, - avatar: memberevent.avatar_url, - }), - ) + .ready_filter(|member| member != &sender_user) + .filter_map(|user_id| { + services + .rooms + .state_accessor + .get_member(room_id, user_id) + .map_ok(|memberevent| SlidingSyncRoomHero { + user_id: user_id.into(), + name: memberevent.displayname, + avatar: memberevent.avatar_url, + }) + .ok() }) - .filter_map(Result::ok) - .flatten() .take(5) - .collect::>(); + .collect() + .await; + let name = match heroes.len().cmp(&(1_usize)) { Ordering::Greater => { let firsts = heroes[1..] @@ -1626,10 +1674,12 @@ pub(crate) async fn sync_events_v4_route( .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) .collect::>() .join(", "); + let last = heroes[0] .name .clone() .unwrap_or_else(|| heroes[0].user_id.to_string()); + Some(format!("{firsts} and {last}")) }, Ordering::Equal => Some( @@ -1650,11 +1700,17 @@ pub(crate) async fn sync_events_v4_route( rooms.insert( room_id.clone(), sync_events::v4::SlidingSyncRoom { - name: services.rooms.state_accessor.get_name(room_id)?.or(name), + name: services + .rooms + .state_accessor + .get_name(room_id) + .await + .ok() + .or(name), avatar: if let Some(heroes_avatar) = heroes_avatar { ruma::JsOption::Some(heroes_avatar) } else { - match services.rooms.state_accessor.get_avatar(room_id)? { + match services.rooms.state_accessor.get_avatar(room_id).await { ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), ruma::JsOption::Null => ruma::JsOption::Null, ruma::JsOption::Undefined => ruma::JsOption::Undefined, @@ -1668,7 +1724,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .highlight_count(&sender_user, room_id)? + .highlight_count(&sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ), @@ -1676,7 +1733,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .user - .notification_count(&sender_user, room_id)? + .notification_count(&sender_user, room_id) + .await .try_into() .expect("notification count can't go that high"), ), @@ -1689,7 +1747,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .state_cache - .room_joined_count(room_id)? + .room_joined_count(room_id) + .await .unwrap_or(0) .try_into() .unwrap_or_else(|_| uint!(0)), @@ -1698,7 +1757,8 @@ pub(crate) async fn sync_events_v4_route( services .rooms .state_cache - .room_invited_count(room_id)? + .room_invited_count(room_id) + .await .unwrap_or(0) .try_into() .unwrap_or_else(|_| uint!(0)), @@ -1732,7 +1792,9 @@ pub(crate) async fn sync_events_v4_route( Some(sync_events::v4::ToDevice { events: services .users - .get_to_device_events(&sender_user, &sender_device)?, + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, next_batch: next_batch.to_string(), }) } else { @@ -1745,7 +1807,8 @@ pub(crate) async fn sync_events_v4_route( }, device_one_time_keys_count: services .users - .count_one_time_keys(&sender_user, &sender_device)?, + .count_one_time_keys(&sender_user, &sender_device) + .await, // Fallback keys are not yet supported device_unused_fallback_key_types: None, }, @@ -1759,25 +1822,26 @@ pub(crate) async fn sync_events_v4_route( }) } -fn filter_rooms( +async fn filter_rooms( rooms: &[OwnedRoomId], State(services): State, filter: &[RoomTypeFilter], negate: bool, ) -> Vec { - return rooms + rooms .iter() - .filter(|r| match services.rooms.state_accessor.get_room_type(r) { - Err(e) => { - warn!("Requested room type for {}, but could not retrieve with error {}", r, e); - false - }, - Ok(result) => { - let result = RoomTypeFilter::from(result); - if negate { - !filter.contains(&result) - } else { - filter.is_empty() || filter.contains(&result) - } - }, + .stream() + .filter_map(|r| async move { + match services.rooms.state_accessor.get_room_type(r).await { + Err(_) => false, + Ok(result) => { + let result = RoomTypeFilter::from(Some(result)); + if negate { + !filter.contains(&result) + } else { + filter.is_empty() || filter.contains(&result) + } + }, + } + .then_some(r.to_owned()) }) - .cloned() - .collect(); + .collect() + .await } diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index 301568e50..bcd0f8170 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -23,10 +23,11 @@ pub(crate) async fn update_tag_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), @@ -41,12 +42,15 @@ pub(crate) async fn update_tag_route( .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(create_tag::v3::Response {}) } @@ -63,10 +67,11 @@ pub(crate) async fn delete_tag_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let mut tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), @@ -78,12 +83,15 @@ pub(crate) async fn delete_tag_route( tags_event.content.tags.remove(&body.tag.clone().into()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(delete_tag::v3::Response {}) } @@ -100,10 +108,11 @@ pub(crate) async fn get_tags_route( let event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag) + .await; let tags_event = event.map_or_else( - || { + |_| { Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 8100f0e67..50f6cdfb2 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::PduEvent; +use futures::StreamExt; use ruma::{ api::client::{error::ErrorKind, threads::get_threads}, uint, @@ -27,20 +29,23 @@ pub(crate) async fn get_threads_route( u64::MAX }; - let threads = services + let room_id = &body.room_id; + let threads: Vec<(u64, PduEvent)> = services .rooms .threads - .threads_until(sender_user, &body.room_id, from, &body.include)? + .threads_until(sender_user, &body.room_id, from, &body.include) + .await? .take(limit) - .filter_map(Result::ok) - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect::>(); + .collect() + .await; let next_batch = threads.last().map(|(count, _)| count.to_string()); diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 1f557ad7b..2b37a9ec5 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -24,8 +25,9 @@ pub(crate) async fn send_event_to_device_route( // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? - .is_some() + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await + .is_ok() { return Ok(send_event_to_device::v3::Response {}); } @@ -53,31 +55,35 @@ pub(crate) async fn send_event_to_device_route( continue; } + let event_type = &body.event_type.to_string(); + + let event = event + .deserialize_as() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?; + match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; + services + .users + .add_to_device_event(sender_user, target_user_id, target_device_id, event_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - sender_user, - target_user_id, - &target_device_id?, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; - } + let (event_type, event) = (&event_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender_user, + target_user_id, + target_device_id, + event_type, + event.clone(), + ) + }) + .await; }, } } @@ -86,7 +92,7 @@ pub(crate) async fn send_event_to_device_route( // Save transaction id with empty data services .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; + .add_txnid(sender_user, sender_device, &body.txn_id, &[]); Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client/typing.rs b/src/api/client/typing.rs index a06648e05..932d221ed 100644 --- a/src/api/client/typing.rs +++ b/src/api/client/typing.rs @@ -16,7 +16,8 @@ pub(crate) async fn create_typing_event_route( if !services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? + .is_joined(sender_user, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room.")); } diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index ab4703fdb..dc570295c 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -2,7 +2,8 @@ use std::collections::BTreeMap; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{warn, Err}; +use conduit::Err; +use futures::StreamExt; use ruma::{ api::{ client::{ @@ -45,7 +46,7 @@ pub(crate) async fn get_mutual_rooms_route( )); } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { return Ok(mutual_rooms::unstable::Response { joined: vec![], next_batch_token: None, @@ -55,9 +56,10 @@ pub(crate) async fn get_mutual_rooms_route( let mutual_rooms: Vec = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - .filter_map(Result::ok) - .collect(); + .get_shared_rooms(sender_user, &body.user_id) + .map(ToOwned::to_owned) + .collect() + .await; Ok(mutual_rooms::unstable::Response { joined: mutual_rooms, @@ -99,7 +101,7 @@ pub(crate) async fn get_room_summary( let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?; - if !services.rooms.metadata.exists(&room_id)? { + if !services.rooms.metadata.exists(&room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } @@ -108,7 +110,7 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -122,50 +124,58 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_canonical_alias(&room_id) - .unwrap_or(None), + .await + .ok(), avatar_url: services .rooms .state_accessor - .get_avatar(&room_id)? + .get_avatar(&room_id) + .await .into_option() .unwrap_or_default() .url, - guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?, - name: services - .rooms - .state_accessor - .get_name(&room_id) - .unwrap_or(None), + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), num_joined_members: services .rooms .state_cache .room_joined_count(&room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) - .try_into() - .expect("user count should not be that big"), + .await + .unwrap_or(0) + .try_into()?, topic: services .rooms .state_accessor .get_room_topic(&room_id) - .unwrap_or(None), + .await + .ok(), world_readable: services .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false), - join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0, - room_type: services.rooms.state_accessor.get_room_type(&room_id)?, - room_version: Some(services.rooms.state.get_room_version(&room_id)?), + .await, + join_rule: services + .rooms + .state_accessor + .get_join_rule(&room_id) + .await + .unwrap_or_default() + .0, + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_version: services.rooms.state.get_room_version(&room_id).await.ok(), membership: if let Some(sender_user) = sender_user { services .rooms .state_accessor - .get_member(&room_id, sender_user)? - .map_or_else(|| Some(MembershipState::Leave), |content| Some(content.membership)) + .get_member(&room_id, sender_user) + .await + .map_or_else(|_| MembershipState::Leave, |content| content.membership) + .into() } else { None }, @@ -173,7 +183,8 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_room_encryption(&room_id) - .unwrap_or_else(|_e| None), + .await + .ok(), }) } @@ -191,13 +202,14 @@ pub(crate) async fn delete_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services.users.set_timezone(&body.user_id, None).await?; + services.users.set_timezone(&body.user_id, None); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_timezone_key::unstable::Response {}) @@ -217,16 +229,14 @@ pub(crate) async fn set_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services - .users - .set_timezone(&body.user_id, body.tz.clone()) - .await?; + services.users.set_timezone(&body.user_id, body.tz.clone()); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_timezone_key::unstable::Response {}) @@ -280,10 +290,11 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), &all_joined_rooms).await?; } else if body.key == "avatar_url" { let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string()); @@ -291,21 +302,23 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, Some(mxc), None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, Some(mxc), None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone())); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_profile_key::unstable::Response {}) @@ -335,30 +348,33 @@ pub(crate) async fn delete_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, None, all_joined_rooms).await?; + update_displayname(&services, &body.user_id, None, &all_joined_rooms).await?; } else if body.key == "avatar_url" { let all_joined_rooms: Vec = services .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, None, None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, None, None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, None)?; + .set_profile_key(&body.user_id, &body.key, None); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_profile_key::unstable::Response {}) @@ -386,26 +402,25 @@ pub(crate) async fn get_timezone_key_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); return Ok(get_timezone_key::unstable::Response { tz: response.tz, @@ -413,14 +428,14 @@ pub(crate) async fn get_timezone_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_timezone_key::unstable::Response { - tz: services.users.timezone(&body.user_id)?, + tz: services.users.timezone(&body.user_id).await.ok(), }) } @@ -448,32 +463,31 @@ pub(crate) async fn get_profile_key_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); if let Some(value) = response.custom_profile_fields.get(&body.key) { profile_key_value.insert(body.key.clone(), value.clone()); services .users - .set_profile_key(&body.user_id, &body.key, Some(value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(value.clone())); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); } @@ -484,13 +498,13 @@ pub(crate) async fn get_profile_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + return Err!(Request(NotFound("Profile was not found."))); } - if let Some(value) = services.users.profile_key(&body.user_id, &body.key)? { + if let Ok(value) = services.users.profile_key(&body.user_id, &body.key).await { profile_key_value.insert(body.key.clone(), value); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index d714fda54..d5bb14e5d 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use axum::{extract::State, response::IntoResponse, Json}; +use futures::StreamExt; use ruma::api::client::{ discovery::{ discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, @@ -173,7 +174,7 @@ pub(crate) async fn conduwuit_server_version() -> Result { /// homeserver. Endpoint is disabled if federation is disabled for privacy. This /// only includes active users (not deactivated, no guests, etc) pub(crate) async fn conduwuit_local_user_count(State(services): State) -> Result { - let user_count = services.users.list_local_users()?.len(); + let user_count = services.users.list_local_users().count().await; Ok(Json(serde_json::json!({ "count": user_count diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 87d4062cd..868811a3f 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::utils::TryFutureExtExt; +use futures::{pin_mut, StreamExt}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -21,14 +23,12 @@ pub(crate) async fn search_users_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 - let mut users = services.users.iter().filter_map(|user_id| { + let users = services.users.stream().filter_map(|user_id| async { // Filter out buggy users (they should not exist, but you never know...) - let user_id = user_id.ok()?; - let user = search_users::v3::User { - user_id: user_id.clone(), - display_name: services.users.displayname(&user_id).ok()?, - avatar_url: services.users.avatar_url(&user_id).ok()?, + user_id: user_id.to_owned(), + display_name: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), }; let user_id_matches = user @@ -56,20 +56,15 @@ pub(crate) async fn search_users_route( let user_is_in_public_rooms = services .rooms .state_cache - .rooms_joined(&user_id) - .filter_map(Result::ok) + .rooms_joined(&user.user_id) .any(|room| { services .rooms .state_accessor - .room_state_get(&room, &StateEventType::RoomJoinRules, "") - .map_or(false, |event| { - event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) - }) - }) - }); + .room_state_get_content::(room, &StateEventType::RoomJoinRules, "") + .map_ok_or(false, |content| content.join_rule == JoinRule::Public) + }) + .await; if user_is_in_public_rooms { user_visible = true; @@ -77,25 +72,22 @@ pub(crate) async fn search_users_route( let user_is_in_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), user_id]) - .ok()? - .next() - .is_some(); + .has_shared_rooms(sender_user, &user.user_id) + .await; if user_is_in_shared_rooms { user_visible = true; } } - if !user_visible { - return None; - } - - Some(user) + user_visible.then_some(user) }); - let results = users.by_ref().take(limit).collect(); - let limited = users.next().is_some(); + pin_mut!(users); + + let limited = users.by_ref().next().await.is_some(); + + let results = users.take(limit).collect().await; Ok(search_users::v3::Response { results, diff --git a/src/api/router.rs b/src/api/router.rs index 4264e01df..c4275f054 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -22,101 +22,101 @@ use crate::{client, server}; pub fn build(router: Router, server: &Server) -> Router { let config = &server.config; let mut router = router - .ruma_route(client::get_timezone_key_route) - .ruma_route(client::get_profile_key_route) - .ruma_route(client::set_profile_key_route) - .ruma_route(client::delete_profile_key_route) - .ruma_route(client::set_timezone_key_route) - .ruma_route(client::delete_timezone_key_route) - .ruma_route(client::appservice_ping) - .ruma_route(client::get_supported_versions_route) - .ruma_route(client::get_register_available_route) - .ruma_route(client::register_route) - .ruma_route(client::get_login_types_route) - .ruma_route(client::login_route) - .ruma_route(client::whoami_route) - .ruma_route(client::logout_route) - .ruma_route(client::logout_all_route) - .ruma_route(client::change_password_route) - .ruma_route(client::deactivate_route) - .ruma_route(client::third_party_route) - .ruma_route(client::request_3pid_management_token_via_email_route) - .ruma_route(client::request_3pid_management_token_via_msisdn_route) - .ruma_route(client::check_registration_token_validity) - .ruma_route(client::get_capabilities_route) - .ruma_route(client::get_pushrules_all_route) - .ruma_route(client::set_pushrule_route) - .ruma_route(client::get_pushrule_route) - .ruma_route(client::set_pushrule_enabled_route) - .ruma_route(client::get_pushrule_enabled_route) - .ruma_route(client::get_pushrule_actions_route) - .ruma_route(client::set_pushrule_actions_route) - .ruma_route(client::delete_pushrule_route) - .ruma_route(client::get_room_event_route) - .ruma_route(client::get_room_aliases_route) - .ruma_route(client::get_filter_route) - .ruma_route(client::create_filter_route) - .ruma_route(client::create_openid_token_route) - .ruma_route(client::set_global_account_data_route) - .ruma_route(client::set_room_account_data_route) - .ruma_route(client::get_global_account_data_route) - .ruma_route(client::get_room_account_data_route) - .ruma_route(client::set_displayname_route) - .ruma_route(client::get_displayname_route) - .ruma_route(client::set_avatar_url_route) - .ruma_route(client::get_avatar_url_route) - .ruma_route(client::get_profile_route) - .ruma_route(client::set_presence_route) - .ruma_route(client::get_presence_route) - .ruma_route(client::upload_keys_route) - .ruma_route(client::get_keys_route) - .ruma_route(client::claim_keys_route) - .ruma_route(client::create_backup_version_route) - .ruma_route(client::update_backup_version_route) - .ruma_route(client::delete_backup_version_route) - .ruma_route(client::get_latest_backup_info_route) - .ruma_route(client::get_backup_info_route) - .ruma_route(client::add_backup_keys_route) - .ruma_route(client::add_backup_keys_for_room_route) - .ruma_route(client::add_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_for_room_route) - .ruma_route(client::delete_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_route) - .ruma_route(client::get_backup_keys_for_room_route) - .ruma_route(client::get_backup_keys_for_session_route) - .ruma_route(client::get_backup_keys_route) - .ruma_route(client::set_read_marker_route) - .ruma_route(client::create_receipt_route) - .ruma_route(client::create_typing_event_route) - .ruma_route(client::create_room_route) - .ruma_route(client::redact_event_route) - .ruma_route(client::report_event_route) - .ruma_route(client::create_alias_route) - .ruma_route(client::delete_alias_route) - .ruma_route(client::get_alias_route) - .ruma_route(client::join_room_by_id_route) - .ruma_route(client::join_room_by_id_or_alias_route) - .ruma_route(client::joined_members_route) - .ruma_route(client::leave_room_route) - .ruma_route(client::forget_room_route) - .ruma_route(client::joined_rooms_route) - .ruma_route(client::kick_user_route) - .ruma_route(client::ban_user_route) - .ruma_route(client::unban_user_route) - .ruma_route(client::invite_user_route) - .ruma_route(client::set_room_visibility_route) - .ruma_route(client::get_room_visibility_route) - .ruma_route(client::get_public_rooms_route) - .ruma_route(client::get_public_rooms_filtered_route) - .ruma_route(client::search_users_route) - .ruma_route(client::get_member_events_route) - .ruma_route(client::get_protocols_route) + .ruma_route(&client::get_timezone_key_route) + .ruma_route(&client::get_profile_key_route) + .ruma_route(&client::set_profile_key_route) + .ruma_route(&client::delete_profile_key_route) + .ruma_route(&client::set_timezone_key_route) + .ruma_route(&client::delete_timezone_key_route) + .ruma_route(&client::appservice_ping) + .ruma_route(&client::get_supported_versions_route) + .ruma_route(&client::get_register_available_route) + .ruma_route(&client::register_route) + .ruma_route(&client::get_login_types_route) + .ruma_route(&client::login_route) + .ruma_route(&client::whoami_route) + .ruma_route(&client::logout_route) + .ruma_route(&client::logout_all_route) + .ruma_route(&client::change_password_route) + .ruma_route(&client::deactivate_route) + .ruma_route(&client::third_party_route) + .ruma_route(&client::request_3pid_management_token_via_email_route) + .ruma_route(&client::request_3pid_management_token_via_msisdn_route) + .ruma_route(&client::check_registration_token_validity) + .ruma_route(&client::get_capabilities_route) + .ruma_route(&client::get_pushrules_all_route) + .ruma_route(&client::set_pushrule_route) + .ruma_route(&client::get_pushrule_route) + .ruma_route(&client::set_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_actions_route) + .ruma_route(&client::set_pushrule_actions_route) + .ruma_route(&client::delete_pushrule_route) + .ruma_route(&client::get_room_event_route) + .ruma_route(&client::get_room_aliases_route) + .ruma_route(&client::get_filter_route) + .ruma_route(&client::create_filter_route) + .ruma_route(&client::create_openid_token_route) + .ruma_route(&client::set_global_account_data_route) + .ruma_route(&client::set_room_account_data_route) + .ruma_route(&client::get_global_account_data_route) + .ruma_route(&client::get_room_account_data_route) + .ruma_route(&client::set_displayname_route) + .ruma_route(&client::get_displayname_route) + .ruma_route(&client::set_avatar_url_route) + .ruma_route(&client::get_avatar_url_route) + .ruma_route(&client::get_profile_route) + .ruma_route(&client::set_presence_route) + .ruma_route(&client::get_presence_route) + .ruma_route(&client::upload_keys_route) + .ruma_route(&client::get_keys_route) + .ruma_route(&client::claim_keys_route) + .ruma_route(&client::create_backup_version_route) + .ruma_route(&client::update_backup_version_route) + .ruma_route(&client::delete_backup_version_route) + .ruma_route(&client::get_latest_backup_info_route) + .ruma_route(&client::get_backup_info_route) + .ruma_route(&client::add_backup_keys_route) + .ruma_route(&client::add_backup_keys_for_room_route) + .ruma_route(&client::add_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_for_room_route) + .ruma_route(&client::delete_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_route) + .ruma_route(&client::get_backup_keys_for_room_route) + .ruma_route(&client::get_backup_keys_for_session_route) + .ruma_route(&client::get_backup_keys_route) + .ruma_route(&client::set_read_marker_route) + .ruma_route(&client::create_receipt_route) + .ruma_route(&client::create_typing_event_route) + .ruma_route(&client::create_room_route) + .ruma_route(&client::redact_event_route) + .ruma_route(&client::report_event_route) + .ruma_route(&client::create_alias_route) + .ruma_route(&client::delete_alias_route) + .ruma_route(&client::get_alias_route) + .ruma_route(&client::join_room_by_id_route) + .ruma_route(&client::join_room_by_id_or_alias_route) + .ruma_route(&client::joined_members_route) + .ruma_route(&client::leave_room_route) + .ruma_route(&client::forget_room_route) + .ruma_route(&client::joined_rooms_route) + .ruma_route(&client::kick_user_route) + .ruma_route(&client::ban_user_route) + .ruma_route(&client::unban_user_route) + .ruma_route(&client::invite_user_route) + .ruma_route(&client::set_room_visibility_route) + .ruma_route(&client::get_room_visibility_route) + .ruma_route(&client::get_public_rooms_route) + .ruma_route(&client::get_public_rooms_filtered_route) + .ruma_route(&client::search_users_route) + .ruma_route(&client::get_member_events_route) + .ruma_route(&client::get_protocols_route) .route("/_matrix/client/unstable/thirdparty/protocols", get(client::get_protocols_route_unstable)) - .ruma_route(client::send_message_event_route) - .ruma_route(client::send_state_event_for_key_route) - .ruma_route(client::get_state_events_route) - .ruma_route(client::get_state_events_for_key_route) + .ruma_route(&client::send_message_event_route) + .ruma_route(&client::send_state_event_for_key_route) + .ruma_route(&client::get_state_events_route) + .ruma_route(&client::get_state_events_for_key_route) // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes // share one Ruma request / response type pair with {get,send}_state_event_for_key_route .route( @@ -140,46 +140,46 @@ pub fn build(router: Router, server: &Server) -> Router { get(client::get_state_events_for_empty_key_route) .put(client::send_state_event_for_empty_key_route), ) - .ruma_route(client::sync_events_route) - .ruma_route(client::sync_events_v4_route) - .ruma_route(client::get_context_route) - .ruma_route(client::get_message_events_route) - .ruma_route(client::search_events_route) - .ruma_route(client::turn_server_route) - .ruma_route(client::send_event_to_device_route) - .ruma_route(client::create_content_route) - .ruma_route(client::get_content_thumbnail_route) - .ruma_route(client::get_content_route) - .ruma_route(client::get_content_as_filename_route) - .ruma_route(client::get_media_preview_route) - .ruma_route(client::get_media_config_route) - .ruma_route(client::get_devices_route) - .ruma_route(client::get_device_route) - .ruma_route(client::update_device_route) - .ruma_route(client::delete_device_route) - .ruma_route(client::delete_devices_route) - .ruma_route(client::get_tags_route) - .ruma_route(client::update_tag_route) - .ruma_route(client::delete_tag_route) - .ruma_route(client::upload_signing_keys_route) - .ruma_route(client::upload_signatures_route) - .ruma_route(client::get_key_changes_route) - .ruma_route(client::get_pushers_route) - .ruma_route(client::set_pushers_route) - .ruma_route(client::upgrade_room_route) - .ruma_route(client::get_threads_route) - .ruma_route(client::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(client::get_relating_events_with_rel_type_route) - .ruma_route(client::get_relating_events_route) - .ruma_route(client::get_hierarchy_route) - .ruma_route(client::get_mutual_rooms_route) - .ruma_route(client::get_room_summary) + .ruma_route(&client::sync_events_route) + .ruma_route(&client::sync_events_v4_route) + .ruma_route(&client::get_context_route) + .ruma_route(&client::get_message_events_route) + .ruma_route(&client::search_events_route) + .ruma_route(&client::turn_server_route) + .ruma_route(&client::send_event_to_device_route) + .ruma_route(&client::create_content_route) + .ruma_route(&client::get_content_thumbnail_route) + .ruma_route(&client::get_content_route) + .ruma_route(&client::get_content_as_filename_route) + .ruma_route(&client::get_media_preview_route) + .ruma_route(&client::get_media_config_route) + .ruma_route(&client::get_devices_route) + .ruma_route(&client::get_device_route) + .ruma_route(&client::update_device_route) + .ruma_route(&client::delete_device_route) + .ruma_route(&client::delete_devices_route) + .ruma_route(&client::get_tags_route) + .ruma_route(&client::update_tag_route) + .ruma_route(&client::delete_tag_route) + .ruma_route(&client::upload_signing_keys_route) + .ruma_route(&client::upload_signatures_route) + .ruma_route(&client::get_key_changes_route) + .ruma_route(&client::get_pushers_route) + .ruma_route(&client::set_pushers_route) + .ruma_route(&client::upgrade_room_route) + .ruma_route(&client::get_threads_route) + .ruma_route(&client::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(&client::get_relating_events_with_rel_type_route) + .ruma_route(&client::get_relating_events_route) + .ruma_route(&client::get_hierarchy_route) + .ruma_route(&client::get_mutual_rooms_route) + .ruma_route(&client::get_room_summary) .route( "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", get(client::get_room_summary_legacy) ) - .ruma_route(client::well_known_support) - .ruma_route(client::well_known_client) + .ruma_route(&client::well_known_support) + .ruma_route(&client::well_known_client) .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) @@ -187,35 +187,35 @@ pub fn build(router: Router, server: &Server) -> Router { if config.allow_federation { router = router - .ruma_route(server::get_server_version_route) + .ruma_route(&server::get_server_version_route) .route("/_matrix/key/v2/server", get(server::get_server_keys_route)) .route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route)) - .ruma_route(server::get_public_rooms_route) - .ruma_route(server::get_public_rooms_filtered_route) - .ruma_route(server::send_transaction_message_route) - .ruma_route(server::get_event_route) - .ruma_route(server::get_backfill_route) - .ruma_route(server::get_missing_events_route) - .ruma_route(server::get_event_authorization_route) - .ruma_route(server::get_room_state_route) - .ruma_route(server::get_room_state_ids_route) - .ruma_route(server::create_leave_event_template_route) - .ruma_route(server::create_leave_event_v1_route) - .ruma_route(server::create_leave_event_v2_route) - .ruma_route(server::create_join_event_template_route) - .ruma_route(server::create_join_event_v1_route) - .ruma_route(server::create_join_event_v2_route) - .ruma_route(server::create_invite_route) - .ruma_route(server::get_devices_route) - .ruma_route(server::get_room_information_route) - .ruma_route(server::get_profile_information_route) - .ruma_route(server::get_keys_route) - .ruma_route(server::claim_keys_route) - .ruma_route(server::get_openid_userinfo_route) - .ruma_route(server::get_hierarchy_route) - .ruma_route(server::well_known_server) - .ruma_route(server::get_content_route) - .ruma_route(server::get_content_thumbnail_route) + .ruma_route(&server::get_public_rooms_route) + .ruma_route(&server::get_public_rooms_filtered_route) + .ruma_route(&server::send_transaction_message_route) + .ruma_route(&server::get_event_route) + .ruma_route(&server::get_backfill_route) + .ruma_route(&server::get_missing_events_route) + .ruma_route(&server::get_event_authorization_route) + .ruma_route(&server::get_room_state_route) + .ruma_route(&server::get_room_state_ids_route) + .ruma_route(&server::create_leave_event_template_route) + .ruma_route(&server::create_leave_event_v1_route) + .ruma_route(&server::create_leave_event_v2_route) + .ruma_route(&server::create_join_event_template_route) + .ruma_route(&server::create_join_event_v1_route) + .ruma_route(&server::create_join_event_v2_route) + .ruma_route(&server::create_invite_route) + .ruma_route(&server::get_devices_route) + .ruma_route(&server::get_room_information_route) + .ruma_route(&server::get_profile_information_route) + .ruma_route(&server::get_keys_route) + .ruma_route(&server::claim_keys_route) + .ruma_route(&server::get_openid_userinfo_route) + .ruma_route(&server::get_hierarchy_route) + .ruma_route(&server::well_known_server) + .ruma_route(&server::get_content_route) + .ruma_route(&server::get_content_thumbnail_route) .route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count)); } else { router = router @@ -227,11 +227,11 @@ pub fn build(router: Router, server: &Server) -> Router { if config.allow_legacy_media { router = router - .ruma_route(client::get_media_config_legacy_route) - .ruma_route(client::get_media_preview_legacy_route) - .ruma_route(client::get_content_legacy_route) - .ruma_route(client::get_content_as_filename_legacy_route) - .ruma_route(client::get_content_thumbnail_legacy_route) + .ruma_route(&client::get_media_config_legacy_route) + .ruma_route(&client::get_media_preview_legacy_route) + .ruma_route(&client::get_content_legacy_route) + .ruma_route(&client::get_content_as_filename_legacy_route) + .ruma_route(&client::get_content_thumbnail_legacy_route) .route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route)) .route("/_matrix/media/v1/upload", post(client::create_content_legacy_route)) .route( diff --git a/src/api/router/args.rs b/src/api/router/args.rs index a3d09dff5..7381a55f5 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -10,7 +10,10 @@ use super::{auth, auth::Auth, request, request::Request}; use crate::{service::appservice::RegistrationInfo, State}; /// Extractor for Ruma request structs -pub(crate) struct Args { +pub(crate) struct Args +where + T: IncomingRequest + Send + Sync + 'static, +{ /// Request struct body pub(crate) body: T, @@ -38,7 +41,7 @@ pub(crate) struct Args { #[async_trait] impl FromRequest for Args where - T: IncomingRequest, + T: IncomingRequest + Send + Sync + 'static, { type Rejection = Error; @@ -57,7 +60,10 @@ where } } -impl Deref for Args { +impl Deref for Args +where + T: IncomingRequest + Send + Sync + 'static, +{ type Target = T; fn deref(&self) -> &Self::Target { &self.body } @@ -67,7 +73,7 @@ fn make_body( services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, ) -> Result where - T: IncomingRequest, + T: IncomingRequest + Send + Sync + 'static, { let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { let user_id = auth.sender_user.clone().unwrap_or_else(|| { @@ -77,15 +83,13 @@ where let uiaa_request = json_body .get("auth") - .and_then(|auth| auth.as_object()) + .and_then(CanonicalJsonValue::as_object) .and_then(|auth| auth.get("session")) - .and_then(|session| session.as_str()) + .and_then(CanonicalJsonValue::as_str) .and_then(|session| { - services.uiaa.get_uiaa_request( - &user_id, - &auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()), - session, - ) + services + .uiaa + .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) }); if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 670f72ba8..8d76b4be8 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -44,8 +44,8 @@ pub(super) async fn auth( let token = if let Some(token) = token { if let Some(reg_info) = services.appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info)) - } else if let Some((user_id, device_id)) = services.users.find_from_token(token)? { - Token::User((user_id, OwnedDeviceId::from(device_id))) + } else if let Ok((user_id, device_id)) = services.users.find_from_token(token).await { + Token::User((user_id, device_id)) } else { Token::Invalid } @@ -98,7 +98,7 @@ pub(super) async fn auth( )) } }, - (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?), + (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?), (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { Ok(Auth { origin: None, @@ -150,7 +150,7 @@ pub(super) async fn auth( } } -fn auth_appservice(services: &Services, request: &Request, info: Box) -> Result { +async fn auth_appservice(services: &Services, request: &Request, info: Box) -> Result { let user_id = request .query .user_id @@ -170,7 +170,7 @@ fn auth_appservice(services: &Services, request: &Request, info: Box { + fn add_route(&'static self, router: Router, path: &str) -> Router; + fn add_routes(&'static self, router: Router) -> Router; +} + pub(in super::super) trait RouterExt { - fn ruma_route(self, handler: H) -> Self + fn ruma_route(self, handler: &'static H) -> Self where H: RumaHandler; } impl RouterExt for Router { - fn ruma_route(self, handler: H) -> Self + fn ruma_route(self, handler: &'static H) -> Self where H: RumaHandler, { @@ -27,34 +31,28 @@ impl RouterExt for Router { } } -pub(in super::super) trait RumaHandler { - fn add_routes(&self, router: Router) -> Router; - - fn add_route(&self, router: Router, path: &str) -> Router; -} - macro_rules! ruma_handler { ( $($tx:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl RumaHandler<($($tx,)* Ruma,)> for Fun + impl RumaHandler<($($tx,)* Ruma,)> for Fun where - Req: IncomingRequest + Send + 'static, - Ret: IntoResponse, - Fut: Future> + Send, - Fun: FnOnce($($tx,)* Ruma,) -> Fut + Clone + Send + Sync + 'static, - $( $tx: FromRequestParts + Send + 'static, )* + Fun: Fn($($tx,)* Ruma,) -> Fut + Send + Sync + 'static, + Fut: Future> + Send, + Req: IncomingRequest + Send + Sync, + Err: IntoResponse + Send, + ::OutgoingResponse: Send, + $( $tx: FromRequestParts + Send + Sync + 'static, )* { - fn add_routes(&self, router: Router) -> Router { + fn add_routes(&'static self, router: Router) -> Router { Req::METADATA .history .all_paths() .fold(router, |router, path| self.add_route(router, path)) } - fn add_route(&self, router: Router, path: &str) -> Router { - let handle = self.clone(); + fn add_route(&'static self, router: Router, path: &str) -> Router { + let action = |$($tx,)* req| self($($tx,)* req).map_ok(RumaResponse); let method = method_to_filter(&Req::METADATA.method); - let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) }; router.route(path, on(method, action)) } } diff --git a/src/api/router/response.rs b/src/api/router/response.rs index 2aaa79faa..70bbb9364 100644 --- a/src/api/router/response.rs +++ b/src/api/router/response.rs @@ -5,13 +5,18 @@ use http::StatusCode; use http_body_util::Full; use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; -pub(crate) struct RumaResponse(pub(crate) T); +pub(crate) struct RumaResponse(pub(crate) T) +where + T: OutgoingResponse; impl From for RumaResponse { fn from(t: Error) -> Self { Self(t.into()) } } -impl IntoResponse for RumaResponse { +impl IntoResponse for RumaResponse +where + T: OutgoingResponse, +{ fn into_response(self) -> Response { self.0 .try_into_http_response::() diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 1b665c19d..2bbc95ca9 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -1,9 +1,13 @@ +use std::cmp; + use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::backfill::get_backfill}, - uint, user_id, MilliSecondsSinceUnixEpoch, +use conduit::{ + is_equal_to, + utils::{IterStream, ReadyExt}, + Err, PduCount, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch}; use crate::Ruma; @@ -19,27 +23,35 @@ pub(crate) async fn get_backfill_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let until = body .v .iter() - .map(|event_id| services.rooms.timeline.get_pdu_count(event_id)) - .filter_map(|r| r.ok().flatten()) - .max() - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; + .stream() + .filter_map(|event_id| { + services + .rooms + .timeline + .get_pdu_count(event_id) + .map(Result::ok) + }) + .ready_fold(PduCount::Backfilled(0), cmp::max) + .await; let limit = body .limit @@ -47,31 +59,37 @@ pub(crate) async fn get_backfill_route( .try_into() .expect("UInt could not be converted to usize"); - let all_events = services + let pdus = services .rooms .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? - .take(limit); + .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) + .await? + .take(limit) + .filter_map(|(_, pdu)| async move { + if !services + .rooms + .state_accessor + .server_can_see_event(origin, &pdu.room_id, &pdu.event_id) + .await + .is_ok_and(is_equal_to!(true)) + { + return None; + } - let events = all_events - .filter_map(Result::ok) - .filter(|(_, e)| { - matches!( - services - .rooms - .state_accessor - .server_can_see_event(origin, &e.room_id, &e.event_id,), - Ok(true), - ) + services + .rooms + .timeline + .get_pdu_json(&pdu.event_id) + .await + .ok() }) - .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) - .filter_map(|r| r.ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(); + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_backfill::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdus: events, + pdus, }) } diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e11a01a20..e4eac794f 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,9 +1,6 @@ use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::event::get_event}, - MilliSecondsSinceUnixEpoch, RoomId, -}; +use conduit::{err, Err, Result}; +use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId}; use crate::Ruma; @@ -21,34 +18,46 @@ pub(crate) async fn get_event_route( let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event not found."))))?; let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database."))?; + .ok_or_else(|| err!(Database("Invalid event in database.")))?; let room_id = - <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; + <&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - if !services.rooms.state_accessor.is_world_readable(room_id)? - && !services.rooms.state_cache.server_in_room(origin, room_id)? + if !services + .rooms + .state_accessor + .is_world_readable(room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } if !services .rooms .state_accessor - .server_can_see_event(origin, room_id, &body.event_id)? + .server_can_see_event(origin, room_id, &body.event_id) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not allowed to see event.")); + return Err!(Request(Forbidden("Server is not allowed to see event."))); } Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: services.sending.convert_to_outgoing_federation_event(event), + pdu: services + .sending + .convert_to_outgoing_federation_event(event) + .await, }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 4b0f6bc00..8307a4ad3 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,7 +1,8 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, @@ -22,16 +23,18 @@ pub(crate) async fn get_event_authorization_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } @@ -39,8 +42,9 @@ pub(crate) async fn get_event_authorization_route( let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; let room_id_str = event .get("room_id") @@ -50,16 +54,17 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .event_ids_iter(room_id, &[body.event_id.borrow()]) + .await? + .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_event_authorization::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, }) } diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index e2c3c93cf..7ae0ff608 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -18,16 +18,18 @@ pub(crate) async fn get_missing_events_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); } @@ -43,7 +45,12 @@ pub(crate) async fn get_missing_events_route( let mut i: usize = 0; while i < queued_events.len() && events.len() < limit { - if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? { + if let Ok(pdu) = services + .rooms + .timeline + .get_pdu_json(&queued_events[i]) + .await + { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -64,7 +71,8 @@ pub(crate) async fn get_missing_events_route( if !services .rooms .state_accessor - .server_can_see_event(origin, &body.room_id, &queued_events[i])? + .server_can_see_event(origin, &body.room_id, &queued_events[i]) + .await? { i = i.saturating_add(1); continue; @@ -81,7 +89,12 @@ pub(crate) async fn get_missing_events_route( ) .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, ); - events.push(services.sending.convert_to_outgoing_federation_event(pdu)); + events.push( + services + .sending + .convert_to_outgoing_federation_event(pdu) + .await, + ); } i = i.saturating_add(1); } diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index 530ed1456..002bd7633 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -12,7 +12,7 @@ pub(crate) async fn get_hierarchy_route( ) -> Result { let origin = body.origin.as_ref().expect("server is authenticated"); - if services.rooms.metadata.exists(&body.room_id)? { + if services.rooms.metadata.exists(&body.room_id).await { services .rooms .spaces diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 688e026c5..9968bdf72 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -24,7 +24,8 @@ pub(crate) async fn create_invite_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .globals @@ -98,7 +99,8 @@ pub(crate) async fn create_invite_route( services .rooms .event_handler - .acl_check(invited_user.server_name(), &body.room_id)?; + .acl_check(invited_user.server_name(), &body.room_id) + .await?; ruma::signatures::hash_and_sign_event( services.globals.server_name().as_str(), @@ -128,14 +130,14 @@ pub(crate) async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; - if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? { + if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "This room is banned on this homeserver.", )); } - if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? { + if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "This server does not allow room invites.", @@ -159,22 +161,28 @@ pub(crate) async fn create_invite_route( if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), &body.room_id)? + .server_in_room(services.globals.server_name(), &body.room_id) + .await { - services.rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - RoomMemberEventContent::new(MembershipState::Invite), - &sender, - Some(invite_state), - body.via.clone(), - true, - )?; + services + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + RoomMemberEventContent::new(MembershipState::Invite), + &sender, + Some(invite_state), + body.via.clone(), + true, + ) + .await?; } Ok(create_invite::v2::Response { event: services .sending - .convert_to_outgoing_federation_event(signed_event), + .convert_to_outgoing_federation_event(signed_event) + .await, }) } diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index 021016be2..ba081aade 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::utils::{IterStream, ReadyExt}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_join_event}, events::{ @@ -24,7 +26,7 @@ use crate::{ pub(crate) async fn create_join_event_template_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -40,7 +42,8 @@ pub(crate) async fn create_join_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if services .globals @@ -73,7 +76,7 @@ pub(crate) async fn create_join_event_template_route( } } - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -81,22 +84,24 @@ pub(crate) async fn create_join_event_template_route( .rooms .state_cache .is_left(&body.user_id, &body.room_id) - .unwrap_or(true)) - && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)? + .await) + && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id).await? { let auth_user = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(Result::ok) - .filter(|user| user.server_name() == services.globals.server_name()) - .find(|user| { + .ready_filter(|user| user.server_name() == services.globals.server_name()) + .filter(|user| { services .rooms .state_accessor .user_can_invite(&body.room_id, user, &body.user_id, &state_lock) - .unwrap_or(false) - }); + }) + .boxed() + .next() + .await + .map(ToOwned::to_owned); if auth_user.is_some() { auth_user @@ -110,7 +115,7 @@ pub(crate) async fn create_join_event_template_route( None }; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { @@ -132,19 +137,23 @@ pub(crate) async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + timestamp: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); @@ -161,7 +170,7 @@ pub(crate) async fn create_join_event_template_route( /// This doesn't check the current user's membership. This should be done /// externally, either by using the state cache or attempting to authorize the /// event. -pub(crate) fn user_can_perform_restricted_join( +pub(crate) async fn user_can_perform_restricted_join( services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result { use RoomVersionId::*; @@ -169,18 +178,15 @@ pub(crate) fn user_can_perform_restricted_join( let join_rules_event = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let Some(join_rules_event_content) = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str::(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event in database: {e}"); - Error::bad_database("Invalid join rules event in database") - }) + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await; + + let Ok(Ok(join_rules_event_content)) = join_rules_event.as_ref().map(|join_rules_event| { + serde_json::from_str::(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event in database: {e}"); + Error::bad_database("Invalid join rules event in database") }) - .transpose()? - else { + }) else { return Ok(false); }; @@ -201,13 +207,10 @@ pub(crate) fn user_can_perform_restricted_join( None } }) - .any(|m| { - services - .rooms - .state_cache - .is_joined(user_id, &m.room_id) - .unwrap_or(false) - }) { + .stream() + .any(|m| services.rooms.state_cache.is_joined(user_id, &m.room_id)) + .await + { Ok(true) } else { Err(Error::BadRequest( diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 3eb0d77ab..41ea1c80d 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -18,7 +18,7 @@ use crate::{service::pdu::PduBuilder, Ruma}; pub(crate) async fn create_leave_event_template_route( State(services): State, body: Ruma, ) -> Result { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -34,9 +34,10 @@ pub(crate) async fn create_leave_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, @@ -50,19 +51,23 @@ pub(crate) async fn create_leave_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + timestamp: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); diff --git a/src/api/server/openid.rs b/src/api/server/openid.rs index 6a1b99b75..9b54807a6 100644 --- a/src/api/server/openid.rs +++ b/src/api/server/openid.rs @@ -10,6 +10,9 @@ pub(crate) async fn get_openid_userinfo_route( State(services): State, body: Ruma, ) -> Result { Ok(get_openid_userinfo::v1::Response::new( - services.users.find_from_openid_token(&body.access_token)?, + services + .users + .find_from_openid_token(&body.access_token) + .await?, )) } diff --git a/src/api/server/query.rs b/src/api/server/query.rs index c2b78bded..348b8c6e9 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,7 +1,8 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{err, Error, Result}; +use futures::StreamExt; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ @@ -23,15 +24,17 @@ pub(crate) async fn get_room_information_route( let room_id = services .rooms .alias - .resolve_local_alias(&body.room_alias)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; + .resolve_local_alias(&body.room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room alias not found."))))?; let mut servers: Vec = services .rooms .state_cache .room_servers(&room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; servers.sort_unstable(); servers.dedup(); @@ -82,30 +85,31 @@ pub(crate) async fn get_profile_information_route( match &body.field { Some(ProfileField::DisplayName) => { - displayname = services.users.displayname(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); }, Some(ProfileField::AvatarUrl) => { - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); }, Some(custom_field) => { - if let Some(value) = services + if let Ok(value) = services .users - .profile_key(&body.user_id, custom_field.as_str())? + .profile_key(&body.user_id, custom_field.as_str()) + .await { custom_profile_fields.insert(custom_field.to_string(), value); } }, None => { - displayname = services.users.displayname(&body.user_id)?; - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; - tz = services.users.timezone(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); + tz = services.users.timezone(&body.user_id).await.ok(); custom_profile_fields = services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(); + .collect() + .await; }, } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 15f82faa7..bb4249881 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -2,7 +2,8 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant}; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug, debug_warn, err, trace, warn, Err}; +use conduit::{debug, debug_warn, err, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::error::ErrorKind, @@ -23,10 +24,13 @@ use tokio::sync::RwLock; use crate::{ services::Services, utils::{self}, - Error, Result, Ruma, + Ruma, }; -type ResolvedMap = BTreeMap>; +const PDU_LIMIT: usize = 50; +const EDU_LIMIT: usize = 100; + +type ResolvedMap = BTreeMap>; /// # `PUT /_matrix/federation/v1/send/{txnId}` /// @@ -44,12 +48,16 @@ pub(crate) async fn send_transaction_message_route( ))); } - if body.pdus.len() > 50_usize { - return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction"))); + if body.pdus.len() > PDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {PDU_LIMIT} PDUs in one transaction" + ))); } - if body.edus.len() > 100_usize { - return Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction"))); + if body.edus.len() > EDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {EDU_LIMIT} EDUs in one transaction" + ))); } let txn_start_time = Instant::now(); @@ -62,8 +70,8 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?; - handle_edus(&services, &client, &body, origin).await?; + let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await; + handle_edus(&services, &client, &body, origin).await; debug!( pdus = ?body.pdus.len(), @@ -85,10 +93,10 @@ pub(crate) async fn send_transaction_message_route( async fn handle_pdus( services: &Services, _client: &IpAddr, body: &Ruma, origin: &ServerName, txn_start_time: &Instant, -) -> Result { +) -> ResolvedMap { let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); for pdu in &body.pdus { - parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { + parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { Ok(t) => t, Err(e) => { debug_warn!("Could not parse PDU: {e}"); @@ -151,38 +159,34 @@ async fn handle_pdus( } } - Ok(resolved_map) + resolved_map } async fn handle_edus( services: &Services, client: &IpAddr, body: &Ruma, origin: &ServerName, -) -> Result<()> { +) { for edu in body .edus .iter() .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { match edu { - Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?, - Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?, - Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?, - Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await?, - Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await?, - Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?, + Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await, + Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await, + Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await, + Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await, + Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await, + Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await, Edu::_Custom(ref _custom) => { debug_warn!(?body.edus, "received custom/unknown EDU"); }, } } - - Ok(()) } -async fn handle_edu_presence( - services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent, -) -> Result<()> { +async fn handle_edu_presence(services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent) { if !services.globals.allow_incoming_presence() { - return Ok(()); + return; } for update in presence.push { @@ -194,23 +198,24 @@ async fn handle_edu_presence( continue; } - services.presence.set_presence( - &update.user_id, - &update.presence, - Some(update.currently_active), - Some(update.last_active_ago), - update.status_msg.clone(), - )?; + services + .presence + .set_presence( + &update.user_id, + &update.presence, + Some(update.currently_active), + Some(update.last_active_ago), + update.status_msg.clone(), + ) + .await + .log_err() + .ok(); } - - Ok(()) } -async fn handle_edu_receipt( - services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent, -) -> Result<()> { +async fn handle_edu_receipt(services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) { if !services.globals.allow_incoming_read_receipts() { - return Ok(()); + return; } for (room_id, room_updates) in receipt.receipts { @@ -218,6 +223,7 @@ async fn handle_edu_receipt( .rooms .event_handler .acl_check(origin, &room_id) + .await .is_err() { debug_warn!( @@ -240,8 +246,8 @@ async fn handle_edu_receipt( .rooms .state_cache .room_members(&room_id) - .filter_map(Result::ok) - .any(|member| member.server_name() == user_id.server_name()) + .ready_any(|member| member.server_name() == user_id.server_name()) + .await { for event_id in &user_updates.event_ids { let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]); @@ -255,7 +261,8 @@ async fn handle_edu_receipt( services .rooms .read_receipt - .readreceipt_update(&user_id, &room_id, &event)?; + .readreceipt_update(&user_id, &room_id, &event) + .await; } } else { debug_warn!( @@ -266,15 +273,11 @@ async fn handle_edu_receipt( } } } - - Ok(()) } -async fn handle_edu_typing( - services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent, -) -> Result<()> { +async fn handle_edu_typing(services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent) { if !services.globals.config.allow_incoming_typing { - return Ok(()); + return; } if typing.user_id.server_name() != origin { @@ -282,26 +285,28 @@ async fn handle_edu_typing( %typing.user_id, %origin, "received typing EDU for user not belonging to origin" ); - return Ok(()); + return; } if services .rooms .event_handler .acl_check(typing.user_id.server_name(), &typing.room_id) + .await .is_err() { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for ACL'd user's server" ); - return Ok(()); + return; } if services .rooms .state_cache - .is_joined(&typing.user_id, &typing.room_id)? + .is_joined(&typing.user_id, &typing.room_id) + .await { if typing.typing { let timeout = utils::millis_since_unix_epoch().saturating_add( @@ -315,28 +320,29 @@ async fn handle_edu_typing( .rooms .typing .typing_add(&typing.user_id, &typing.room_id, timeout) - .await?; + .await + .log_err() + .ok(); } else { services .rooms .typing .typing_remove(&typing.user_id, &typing.room_id) - .await?; + .await + .log_err() + .ok(); } } else { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for user not in room" ); - return Ok(()); } - - Ok(()) } async fn handle_edu_device_list_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, -) -> Result<()> { +) { let DeviceListUpdateContent { user_id, .. @@ -347,17 +353,15 @@ async fn handle_edu_device_list_update( %user_id, %origin, "received device list update EDU for user not belonging to origin" ); - return Ok(()); + return; } - services.users.mark_device_key_update(&user_id)?; - - Ok(()) + services.users.mark_device_key_update(&user_id).await; } async fn handle_edu_direct_to_device( services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, -) -> Result<()> { +) { let DirectDeviceContent { sender, ev_type, @@ -370,45 +374,52 @@ async fn handle_edu_direct_to_device( %sender, %origin, "received direct to device EDU for user not belonging to origin" ); - return Ok(()); + return; } // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(&sender, None, &message_id)? - .is_some() + .existing_txnid(&sender, None, &message_id) + .await + .is_ok() { - return Ok(()); + return; } for (target_user_id, map) in &messages { for (target_device_id_maybe, event) in map { + let Ok(event) = event + .deserialize_as() + .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}"))))) + else { + continue; + }; + + let ev_type = ev_type.to_string(); match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?, - )?; + services + .users + .add_to_device_event(&sender, target_user_id, target_device_id, &ev_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?, - )?; - } + let (sender, ev_type, event) = (&sender, &ev_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender, + target_user_id, + target_device_id, + ev_type, + event.clone(), + ) + }) + .await; }, } } @@ -417,14 +428,12 @@ async fn handle_edu_direct_to_device( // Save transaction id with empty data services .transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; - - Ok(()) + .add_txnid(&sender, None, &message_id, &[]); } async fn handle_edu_signing_key_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, -) -> Result<()> { +) { let SigningKeyUpdateContent { user_id, master_key, @@ -436,14 +445,15 @@ async fn handle_edu_signing_key_update( %user_id, %origin, "received signing key update EDU from server that does not belong to user's server" ); - return Ok(()); + return; } if let Some(master_key) = master_key { services .users - .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; + .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true) + .await + .log_err() + .ok(); } - - Ok(()) } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index c4d016f61..f92576904 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,16 +1,17 @@ #![allow(deprecated)] -use std::collections::BTreeMap; +use std::{borrow::Borrow, collections::BTreeMap}; use axum::extract::State; -use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result}; +use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, events::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, + CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::Services; @@ -22,27 +23,32 @@ use crate::Ruma; async fn create_join_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin server - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; // We need to return the state prior to joining, let's keep a reference to that // here let shortstatehash = services .rooms .state - .get_room_shortstatehash(room_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event state not found."))?; + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Request(NotFound("Event state not found."))))?; let pub_key_map = RwLock::new(BTreeMap::new()); // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json @@ -97,7 +103,8 @@ async fn create_join_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; // check if origin server is trying to send for another server if sender.server_name() != origin { @@ -126,7 +133,9 @@ async fn create_join_event( if content .join_authorized_via_users_server .is_some_and(|user| services.globals.user_is_local(&user)) - && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() + && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id) + .await + .unwrap_or_default() { ruma::signatures::hash_and_sign_event( services.globals.server_name().as_str(), @@ -158,12 +167,14 @@ async fn create_join_event( .mutex_federation .lock(room_id) .await; + let pdu_id: Vec = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) .await? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + drop(mutex_lock); let state_ids = services @@ -171,29 +182,44 @@ async fn create_join_event( .state_accessor .state_full_ids(shortstatehash) .await?; - let auth_chain_ids = services + + let state = state_ids + .iter() + .try_stream() + .and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id)) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() + .await?; + + let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect(); + let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, state_ids.values().cloned().collect()) + .event_ids_iter(room_id, &starting_events) + .await? + .map(Ok) + .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, + state, // Event field is required if the room version supports restricted join rules. - event: Some( - to_raw_value(&CanonicalJsonValue::Object(value)) - .expect("To raw json should not fail since only change was adding signature"), - ), + event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(), }) } diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index e77c5d78a..81f41af07 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{utils::ReadyExt, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -49,18 +49,22 @@ pub(crate) async fn create_leave_event_v2_route( async fn create_leave_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result<()> { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; let pub_key_map = RwLock::new(BTreeMap::new()); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json return Err(Error::BadRequest( @@ -114,7 +118,8 @@ async fn create_leave_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; if sender.server_name() != origin { return Err(Error::BadRequest( @@ -173,10 +178,9 @@ async fn create_leave_event( .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| !services.globals.server_is_ours(server)); + .ready_filter(|server| !services.globals.server_is_ours(server)); - services.sending.send_pdu_servers(servers, &pdu_id)?; + services.sending.send_pdu_servers(servers, &pdu_id).await?; Ok(()) } diff --git a/src/api/server/state.rs b/src/api/server/state.rs index d215236af..3a27cd0a3 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; -use conduit::{Error, Result}; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; +use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; +use ruma::api::federation::event::get_room_state; use crate::Ruma; @@ -17,56 +18,66 @@ pub(crate) async fn get_room_state_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; let pdus = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? - .into_values() - .map(|id| { + .await + .log_err() + .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? + .values() + .try_stream() + .and_then(|id| services.rooms.timeline.get_pdu_json(id)) + .and_then(|pdu| { services .sending - .convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) + .convert_to_outgoing_federation_event(pdu) + .map(Ok) }) - .collect(); + .try_collect() + .await?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) + .await? + .map(Ok) + .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; Ok(get_room_state::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| { - services - .rooms - .timeline - .get_pdu_json(&id) - .ok()? - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - }) - .collect(), + auth_chain, pdus, }) } diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index d22f2df4a..b026abf1d 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,9 +1,11 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; +use conduit::{err, Err}; +use futures::StreamExt; +use ruma::api::federation::event::get_room_state_ids; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// @@ -17,31 +19,35 @@ pub(crate) async fn get_room_state_ids_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(origin, &body.room_id) + .await?; if !services .rooms .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? + .is_world_readable(&body.room_id) + .await && !services + .rooms + .state_cache + .server_in_room(origin, &body.room_id) + .await { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + return Err!(Request(Forbidden("Server is not in room."))); } let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; let pdu_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? + .await + .map_err(|_| err!(Request(NotFound("State ids not found"))))? .into_values() .map(|id| (*id).to_owned()) .collect(); @@ -49,11 +55,14 @@ pub(crate) async fn get_room_state_ids_route( let auth_chain_ids = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) + .await? + .map(|id| (*id).to_owned()) + .collect() + .await; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), + auth_chain_ids, pdu_ids, }) } diff --git a/src/api/server/user.rs b/src/api/server/user.rs index e9a400a79..0718da580 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -1,5 +1,6 @@ use axum::extract::State; use conduit::{Error, Result}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::api::{ client::error::ErrorKind, federation::{ @@ -28,41 +29,51 @@ pub(crate) async fn get_devices_route( let origin = body.origin.as_ref().expect("server is authenticated"); + let user_id = &body.user_id; Ok(get_devices::v1::Response { - user_id: body.user_id.clone(), + user_id: user_id.clone(), stream_id: services .users - .get_devicelist_version(&body.user_id)? + .get_devicelist_version(user_id) + .await .unwrap_or(0) - .try_into() - .expect("version will not grow that large"), + .try_into()?, devices: services .users - .all_devices_metadata(&body.user_id) - .filter_map(Result::ok) - .filter_map(|metadata| { - let device_id_string = metadata.device_id.as_str().to_owned(); + .all_devices_metadata(user_id) + .filter_map(|metadata| async move { + let device_id = metadata.device_id.clone(); + let device_id_clone = device_id.clone(); + let device_id_string = device_id.as_str().to_owned(); let device_display_name = if services.globals.allow_device_name_federation() { - metadata.display_name + metadata.display_name.clone() } else { Some(device_id_string) }; - Some(UserDevice { - keys: services - .users - .get_device_keys(&body.user_id, &metadata.device_id) - .ok()??, - device_id: metadata.device_id, - device_display_name, - }) + + services + .users + .get_device_keys(user_id, &device_id_clone) + .map_ok(|keys| UserDevice { + device_id, + keys, + device_display_name, + }) + .map(Result::ok) + .await }) - .collect(), + .collect() + .await, master_key: services .users - .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_master_key(None, &body.user_id, &|u| u.server_name() == origin) + .await + .ok(), self_signing_key: services .users - .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin) + .await + .ok(), }) } diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 713647342..4fe413e93 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -67,6 +67,7 @@ ctor.workspace = true cyborgtime.workspace = true either.workspace = true figment.workspace = true +futures.workspace = true http-body-util.workspace = true http.workspace = true image.workspace = true @@ -82,6 +83,7 @@ ruma.workspace = true sanitize-filename.workspace = true serde_json.workspace = true serde_regex.workspace = true +serde_yaml.workspace = true serde.workspace = true thiserror.workspace = true tikv-jemallocator.optional = true diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index d2d583a8c..d8e1c7d93 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -236,6 +236,8 @@ pub struct Config { #[serde(default)] pub rocksdb_read_only: bool, #[serde(default)] + pub rocksdb_secondary: bool, + #[serde(default)] pub rocksdb_compaction_prio_idle: bool, #[serde(default = "true_fn")] pub rocksdb_compaction_ioprio_idle: bool, @@ -752,6 +754,7 @@ impl fmt::Display for Config { line("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()); line("RocksDB Repair Mode", &self.rocksdb_repair.to_string()); line("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()); + line("RocksDB Secondary Mode", &self.rocksdb_secondary.to_string()); line( "RocksDB Compaction Idle Priority", &self.rocksdb_compaction_prio_idle.to_string(), diff --git a/src/core/debug.rs b/src/core/debug.rs index 844445d53..1e36ca8e2 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,10 +1,10 @@ use std::{any::Any, panic}; -/// Export debug proc_macros +// Export debug proc_macros pub use conduit_macros::recursion_depth; -/// Export all of the ancillary tools from here as well. -pub use crate::utils::debug::*; +// Export all of the ancillary tools from here as well. +pub use crate::{result::DebugInspect, utils::debug::*}; /// Log event at given level in debug-mode (when debug-assertions are enabled). /// In release-mode it becomes DEBUG level, and possibly subject to elision. diff --git a/src/core/error/err.rs b/src/core/error/err.rs index b3d0240ed..82bb40b05 100644 --- a/src/core/error/err.rs +++ b/src/core/error/err.rs @@ -44,34 +44,34 @@ macro_rules! err { (Request(Forbidden($level:ident!($($args:tt)+)))) => {{ let mut buf = String::new(); $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::ruma::api::client::error::ErrorKind::forbidden(), $crate::err_log!(buf, $level, $($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }}; (Request(Forbidden($($args:tt)+))) => { $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::ruma::api::client::error::ErrorKind::forbidden(), $crate::format_maybe!($($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }; (Request($variant:ident($level:ident!($($args:tt)+)))) => {{ let mut buf = String::new(); $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::$variant, + $crate::ruma::api::client::error::ErrorKind::$variant, $crate::err_log!(buf, $level, $($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }}; (Request($variant:ident($($args:tt)+))) => { $crate::error::Error::Request( - ::ruma::api::client::error::ErrorKind::$variant, + $crate::ruma::api::client::error::ErrorKind::$variant, $crate::format_maybe!($($args)+), - ::http::StatusCode::BAD_REQUEST + $crate::http::StatusCode::BAD_REQUEST ) }; @@ -85,6 +85,10 @@ macro_rules! err { $crate::error::Error::$variant($crate::err_log!(buf, $level, $($args)+)) }}; + ($variant:ident($($args:ident),+)) => { + $crate::error::Error::$variant($($args),+) + }; + ($variant:ident($($args:tt)+)) => { $crate::error::Error::$variant($crate::format_maybe!($($args)+)) }; @@ -109,7 +113,7 @@ macro_rules! err_log { ($out:ident, $level:ident, $($fields:tt)+) => {{ use std::{fmt, fmt::Write}; - use ::tracing::{ + use $crate::tracing::{ callsite, callsite2, level_enabled, metadata, valueset, Callsite, Event, __macro_support, __tracing_log, field::{Field, ValueSet, Visit}, @@ -165,25 +169,25 @@ macro_rules! err_log { macro_rules! err_lev { (debug_warn) => { if $crate::debug::logging() { - ::tracing::Level::WARN + $crate::tracing::Level::WARN } else { - ::tracing::Level::DEBUG + $crate::tracing::Level::DEBUG } }; (debug_error) => { if $crate::debug::logging() { - ::tracing::Level::ERROR + $crate::tracing::Level::ERROR } else { - ::tracing::Level::DEBUG + $crate::tracing::Level::DEBUG } }; (warn) => { - ::tracing::Level::WARN + $crate::tracing::Level::WARN }; (error) => { - ::tracing::Level::ERROR + $crate::tracing::Level::ERROR }; } diff --git a/src/core/error/log.rs b/src/core/error/log.rs index c272bf730..60bd70140 100644 --- a/src/core/error/log.rs +++ b/src/core/error/log.rs @@ -1,7 +1,8 @@ use std::{convert::Infallible, fmt}; +use tracing::Level; + use super::Error; -use crate::{debug_error, error}; #[inline] pub fn else_log(error: E) -> Result @@ -64,11 +65,33 @@ where } #[inline] -pub fn inspect_log(error: &E) { - error!("{error}"); +pub fn inspect_log(error: &E) { inspect_log_level(error, Level::ERROR); } + +#[inline] +pub fn inspect_debug_log(error: &E) { inspect_debug_log_level(error, Level::ERROR); } + +#[inline] +pub fn inspect_log_level(error: &E, level: Level) { + use crate::{debug, error, info, trace, warn}; + + match level { + Level::ERROR => error!("{error}"), + Level::WARN => warn!("{error}"), + Level::INFO => info!("{error}"), + Level::DEBUG => debug!("{error}"), + Level::TRACE => trace!("{error}"), + } } #[inline] -pub fn inspect_debug_log(error: &E) { - debug_error!("{error:?}"); +pub fn inspect_debug_log_level(error: &E, level: Level) { + use crate::{debug, debug_error, debug_info, debug_warn, trace}; + + match level { + Level::ERROR => debug_error!("{error:?}"), + Level::WARN => debug_warn!("{error:?}"), + Level::INFO => debug_info!("{error:?}"), + Level::DEBUG => debug!("{error:?}"), + Level::TRACE => trace!("{error:?}"), + } } diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 92dbdfe3b..ad7f9f3ca 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -75,6 +75,8 @@ pub enum Error { TracingFilter(#[from] tracing_subscriber::filter::ParseError), #[error("Tracing reload error: {0}")] TracingReload(#[from] tracing_subscriber::reload::Error), + #[error(transparent)] + Yaml(#[from] serde_yaml::Error), // ruma/conduwuit #[error("Arithmetic operation failed: {0}")] @@ -86,7 +88,7 @@ pub enum Error { #[error("There was a problem with the '{0}' directive in your configuration: {1}")] Config(&'static str, Cow<'static, str>), #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists + Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists #[error(transparent)] ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError), #[error("{0}")] @@ -107,6 +109,8 @@ pub enum Error { Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), #[error(transparent)] Ruma(#[from] ruma::api::client::error::Error), + #[error(transparent)] + StateRes(#[from] ruma::state_res::Error), #[error("uiaa")] Uiaa(ruma::api::client::uiaa::UiaaInfo), @@ -141,19 +145,22 @@ impl Error { use ruma::api::client::error::ErrorKind::Unknown; match self { - Self::Federation(_, error) => response::ruma_error_kind(error).clone(), + Self::Federation(_, error) | Self::Ruma(error) => response::ruma_error_kind(error).clone(), Self::BadRequest(kind, ..) | Self::Request(kind, ..) => kind.clone(), _ => Unknown, } } pub fn status_code(&self) -> http::StatusCode { + use http::StatusCode; + match self { - Self::Federation(_, ref error) | Self::Ruma(ref error) => error.status_code, - Self::Request(ref kind, _, code) => response::status_code(kind, *code), - Self::BadRequest(ref kind, ..) => response::bad_request_code(kind), - Self::Conflict(_) => http::StatusCode::CONFLICT, - _ => http::StatusCode::INTERNAL_SERVER_ERROR, + Self::Federation(_, error) | Self::Ruma(error) => error.status_code, + Self::Request(kind, _, code) => response::status_code(kind, *code), + Self::BadRequest(kind, ..) => response::bad_request_code(kind), + Self::Reqwest(error) => error.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + Self::Conflict(_) => StatusCode::CONFLICT, + _ => StatusCode::INTERNAL_SERVER_ERROR, } } } @@ -176,3 +183,7 @@ impl From for Error { pub fn infallible(_e: &Infallible) { panic!("infallible error should never exist"); } + +#[inline] +#[must_use] +pub fn is_not_found(e: &Error) -> bool { e.status_code() == http::StatusCode::NOT_FOUND } diff --git a/src/core/mod.rs b/src/core/mod.rs index 9898243bf..e45531864 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -7,21 +7,24 @@ pub mod log; pub mod metrics; pub mod mods; pub mod pdu; +pub mod result; pub mod server; pub mod utils; +pub use ::http; +pub use ::ruma; pub use ::toml; +pub use ::tracing; pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; pub use pdu::{PduBuilder, PduCount, PduEvent}; +pub use result::Result; pub use server::Server; pub use utils::{ctor, dtor, implement}; pub use crate as conduit_core; -pub type Result = std::result::Result; - rustc_flags_capture! {} #[cfg(not(conduit_mods))] diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 439c831a5..cf9ffe645 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -3,8 +3,6 @@ mod count; use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; -pub use builder::PduBuilder; -pub use count::PduCount; use ruma::{ canonical_json::redact_content_in_place, events::{ @@ -23,7 +21,8 @@ use serde_json::{ value::{to_raw_value, RawValue as RawJsonValue}, }; -use crate::{err, warn, Error}; +pub use self::{builder::PduBuilder, count::PduCount}; +use crate::{err, warn, Error, Result}; #[derive(Deserialize)] struct ExtractRedactedBecause { @@ -65,11 +64,12 @@ pub struct PduEvent { impl PduEvent { #[tracing::instrument(skip(self), level = "debug")] - pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { + pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> { self.unsigned = None; let mut content = serde_json::from_str(self.content.get()) .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; + redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; @@ -98,31 +98,38 @@ impl PduEvent { unsigned.redacted_because.is_some() } - pub fn remove_transaction_id(&mut self) -> crate::Result<()> { - if let Some(unsigned) = &self.unsigned { - let mut unsigned: BTreeMap> = serde_json::from_str(unsigned.get()) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.remove("transaction_id"); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - } + pub fn remove_transaction_id(&mut self) -> Result<()> { + let Some(unsigned) = &self.unsigned else { + return Ok(()); + }; + + let mut unsigned: BTreeMap> = + serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + unsigned.remove("transaction_id"); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); Ok(()) } - pub fn add_age(&mut self) -> crate::Result<()> { + pub fn add_age(&mut self) -> Result<()> { let mut unsigned: BTreeMap> = self .unsigned .as_ref() .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; // deliberately allowing for the possibility of negative age let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); let then: i128 = self.origin_server_ts.into(); let this_age = now.saturating_sub(then); - unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap()); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid")); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); Ok(()) } @@ -369,9 +376,9 @@ impl state_res::Event for PduEvent { fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - fn prev_events(&self) -> Box + '_> { Box::new(self.prev_events.iter()) } + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.prev_events.iter() } - fn auth_events(&self) -> Box + '_> { Box::new(self.auth_events.iter()) } + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.auth_events.iter() } fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } } @@ -395,7 +402,7 @@ impl Ord for PduEvent { /// CanonicalJsonValue>`. pub fn gen_event_id_canonical_json( pdu: &RawJsonValue, room_version_id: &RoomVersionId, -) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { +) -> Result<(OwnedEventId, CanonicalJsonObject)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; diff --git a/src/core/result.rs b/src/core/result.rs new file mode 100644 index 000000000..82d67a9c5 --- /dev/null +++ b/src/core/result.rs @@ -0,0 +1,14 @@ +mod debug_inspect; +mod into_is_ok; +mod log_debug_err; +mod log_err; +mod map_expect; +mod not_found; +mod unwrap_infallible; + +pub use self::{ + debug_inspect::DebugInspect, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, log_err::LogErr, + map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, +}; + +pub type Result = std::result::Result; diff --git a/src/core/result/debug_inspect.rs b/src/core/result/debug_inspect.rs new file mode 100644 index 000000000..ef80979d8 --- /dev/null +++ b/src/core/result/debug_inspect.rs @@ -0,0 +1,52 @@ +use super::Result; + +/// Inspect Result values with release-mode elision. +pub trait DebugInspect { + /// Inspects an Err contained value in debug-mode. In release-mode closure F + /// is elided. + #[must_use] + fn debug_inspect_err(self, f: F) -> Self; + + /// Inspects an Ok contained value in debug-mode. In release-mode closure F + /// is elided. + #[must_use] + fn debug_inspect(self, f: F) -> Self; +} + +#[cfg(debug_assertions)] +impl DebugInspect for Result { + #[inline] + fn debug_inspect(self, f: F) -> Self + where + F: FnOnce(&T), + { + self.inspect(f) + } + + #[inline] + fn debug_inspect_err(self, f: F) -> Self + where + F: FnOnce(&E), + { + self.inspect_err(f) + } +} + +#[cfg(not(debug_assertions))] +impl DebugInspect for Result { + #[inline] + fn debug_inspect(self, _: F) -> Self + where + F: FnOnce(&T), + { + self + } + + #[inline] + fn debug_inspect_err(self, _: F) -> Self + where + F: FnOnce(&E), + { + self + } +} diff --git a/src/core/result/inspect_log.rs b/src/core/result/inspect_log.rs new file mode 100644 index 000000000..577761c5c --- /dev/null +++ b/src/core/result/inspect_log.rs @@ -0,0 +1,60 @@ +use std::fmt; + +use tracing::Level; + +use super::Result; +use crate::error; + +pub trait ErrLog +where + E: fmt::Display, +{ + fn log_err(self, level: Level) -> Self; + + fn err_log(self) -> Self + where + Self: Sized, + { + self.log_err(Level::ERROR) + } +} + +pub trait ErrDebugLog +where + E: fmt::Debug, +{ + fn log_err_debug(self, level: Level) -> Self; + + fn err_debug_log(self) -> Self + where + Self: Sized, + { + self.log_err_debug(Level::ERROR) + } +} + +impl ErrLog for Result +where + E: fmt::Display, +{ + #[inline] + fn log_err(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_log_level(&error, level)) + } +} + +impl ErrDebugLog for Result +where + E: fmt::Debug, +{ + #[inline] + fn log_err_debug(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_debug_log_level(&error, level)) + } +} diff --git a/src/core/result/into_is_ok.rs b/src/core/result/into_is_ok.rs new file mode 100644 index 000000000..220ce010c --- /dev/null +++ b/src/core/result/into_is_ok.rs @@ -0,0 +1,10 @@ +use super::Result; + +pub trait IntoIsOk { + fn into_is_ok(self) -> bool; +} + +impl IntoIsOk for Result { + #[inline] + fn into_is_ok(self) -> bool { self.is_ok() } +} diff --git a/src/core/result/log_debug_err.rs b/src/core/result/log_debug_err.rs new file mode 100644 index 000000000..8835afd19 --- /dev/null +++ b/src/core/result/log_debug_err.rs @@ -0,0 +1,26 @@ +use std::fmt::Debug; + +use tracing::Level; + +use super::{DebugInspect, Result}; +use crate::error; + +pub trait LogDebugErr { + #[must_use] + fn err_debug_log(self, level: Level) -> Self; + + #[must_use] + fn log_debug_err(self) -> Self + where + Self: Sized, + { + self.err_debug_log(Level::ERROR) + } +} + +impl LogDebugErr for Result { + #[inline] + fn err_debug_log(self, level: Level) -> Self { + self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level)) + } +} diff --git a/src/core/result/log_err.rs b/src/core/result/log_err.rs new file mode 100644 index 000000000..374a5e596 --- /dev/null +++ b/src/core/result/log_err.rs @@ -0,0 +1,24 @@ +use std::fmt::Display; + +use tracing::Level; + +use super::Result; +use crate::error; + +pub trait LogErr { + #[must_use] + fn err_log(self, level: Level) -> Self; + + #[must_use] + fn log_err(self) -> Self + where + Self: Sized, + { + self.err_log(Level::ERROR) + } +} + +impl LogErr for Result { + #[inline] + fn err_log(self, level: Level) -> Self { self.inspect_err(|error| error::inspect_log_level(&error, level)) } +} diff --git a/src/core/result/map_expect.rs b/src/core/result/map_expect.rs new file mode 100644 index 000000000..8ce9195fe --- /dev/null +++ b/src/core/result/map_expect.rs @@ -0,0 +1,15 @@ +use std::fmt::Debug; + +use super::Result; + +pub trait MapExpect { + /// Calls expect(msg) on the mapped Result value. This is similar to + /// map(Result::unwrap) but composes an expect call and message without + /// requiring a closure. + fn map_expect(self, msg: &str) -> Option; +} + +impl MapExpect for Option> { + #[inline] + fn map_expect(self, msg: &str) -> Option { self.map(|result| result.expect(msg)) } +} diff --git a/src/core/result/not_found.rs b/src/core/result/not_found.rs new file mode 100644 index 000000000..69ce821b8 --- /dev/null +++ b/src/core/result/not_found.rs @@ -0,0 +1,12 @@ +use super::Result; +use crate::{error, Error}; + +pub trait NotFound { + #[must_use] + fn is_not_found(&self) -> bool; +} + +impl NotFound for Result { + #[inline] + fn is_not_found(&self) -> bool { self.as_ref().is_err_and(error::is_not_found) } +} diff --git a/src/core/result/unwrap_infallible.rs b/src/core/result/unwrap_infallible.rs new file mode 100644 index 000000000..99309e025 --- /dev/null +++ b/src/core/result/unwrap_infallible.rs @@ -0,0 +1,17 @@ +use std::convert::Infallible; + +use super::{DebugInspect, Result}; +use crate::error; + +pub trait UnwrapInfallible { + fn unwrap_infallible(self) -> T; +} + +impl UnwrapInfallible for Result { + #[inline] + fn unwrap_infallible(self) -> T { + // SAFETY: Branchless unwrap for errors that can never happen. In debug + // mode this is asserted. + unsafe { self.debug_inspect_err(error::infallible).unwrap_unchecked() } + } +} diff --git a/src/core/utils/bool.rs b/src/core/utils/bool.rs new file mode 100644 index 000000000..d7ce78fe3 --- /dev/null +++ b/src/core/utils/bool.rs @@ -0,0 +1,16 @@ +//! Trait BoolExt + +/// Boolean extensions and chain.starters +pub trait BoolExt { + fn or T>(self, f: F) -> Option; + + fn or_some(self, t: T) -> Option; +} + +impl BoolExt for bool { + #[inline] + fn or T>(self, f: F) -> Option { (!self).then(f) } + + #[inline] + fn or_some(self, t: T) -> Option { (!self).then_some(t) } +} diff --git a/src/core/utils/future/mod.rs b/src/core/utils/future/mod.rs new file mode 100644 index 000000000..6d45b6563 --- /dev/null +++ b/src/core/utils/future/mod.rs @@ -0,0 +1,3 @@ +mod try_ext_ext; + +pub use try_ext_ext::TryExtExt; diff --git a/src/core/utils/future/try_ext_ext.rs b/src/core/utils/future/try_ext_ext.rs new file mode 100644 index 000000000..e444ad94a --- /dev/null +++ b/src/core/utils/future/try_ext_ext.rs @@ -0,0 +1,48 @@ +//! Extended external extensions to futures::TryFutureExt + +use futures::{future::MapOkOrElse, TryFuture, TryFutureExt}; + +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait TryExtExt +where + Self: TryFuture + Send, +{ + fn map_ok_or( + self, default: U, f: F, + ) -> MapOkOrElse U, impl FnOnce(Self::Error) -> U> + where + F: FnOnce(Self::Ok) -> U, + Self: Send + Sized; + + fn ok( + self, + ) -> MapOkOrElse Option, impl FnOnce(Self::Error) -> Option> + where + Self: Sized; +} + +impl TryExtExt for Fut +where + Fut: TryFuture + Send, +{ + #[inline] + fn map_ok_or( + self, default: U, f: F, + ) -> MapOkOrElse U, impl FnOnce(Self::Error) -> U> + where + F: FnOnce(Self::Ok) -> U, + Self: Send + Sized, + { + self.map_ok_or_else(|_| default, f) + } + + #[inline] + fn ok( + self, + ) -> MapOkOrElse Option, impl FnOnce(Self::Error) -> Option> + where + Self: Sized, + { + self.map_ok_or(None, Some) + } +} diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs index f9d0de302..215de339c 100644 --- a/src/core/utils/math.rs +++ b/src/core/utils/math.rs @@ -7,32 +7,82 @@ use crate::{debug::type_name, err, Err, Error, Result}; /// Checked arithmetic expression. Returns a Result #[macro_export] macro_rules! checked { - ($($input:tt)*) => { - $crate::utils::math::checked_ops!($($input)*) + ($($input:tt)+) => { + $crate::utils::math::checked_ops!($($input)+) .ok_or_else(|| $crate::err!(Arithmetic("operation overflowed or result invalid"))) - } + }; } -/// in release-mode. Use for performance when the expression is obviously safe. -/// The check remains in debug-mode for regression analysis. +/// Checked arithmetic expression which panics on failure. This is for +/// expressions which do not meet the threshold for validated! but the caller +/// has no realistic expectation for error and no interest in cluttering the +/// callsite with result handling from checked!. +#[macro_export] +macro_rules! expected { + ($msg:literal, $($input:tt)+) => { + $crate::checked!($($input)+).expect($msg) + }; + + ($($input:tt)+) => { + $crate::expected!("arithmetic expression expectation failure", $($input)+) + }; +} + +/// Unchecked arithmetic expression in release-mode. Use for performance when +/// the expression is obviously safe. The check remains in debug-mode for +/// regression analysis. #[cfg(not(debug_assertions))] #[macro_export] macro_rules! validated { - ($($input:tt)*) => { + ($($input:tt)+) => { //#[allow(clippy::arithmetic_side_effects)] { //Some($($input)*) // .ok_or_else(|| $crate::err!(Arithmetic("this error should never been seen"))) //} //NOTE: remove me when stmt_expr_attributes is stable - $crate::checked!($($input)*) - } + $crate::expected!("validated arithmetic expression failed", $($input)+) + }; } +/// Checked arithmetic expression in debug-mode. Use for performance when +/// the expression is obviously safe. The check is elided in release-mode. #[cfg(debug_assertions)] #[macro_export] macro_rules! validated { - ($($input:tt)*) => { $crate::checked!($($input)*) } + ($($input:tt)+) => { $crate::expected!($($input)+) } +} + +/// Functor for equality to zero +#[macro_export] +macro_rules! is_zero { + () => { + $crate::is_matching!(0) + }; +} + +/// Functor for equality i.e. .is_some_and(is_equal!(2)) +#[macro_export] +macro_rules! is_equal_to { + ($val:expr) => { + |x| (x == $val) + }; +} + +/// Functor for less i.e. .is_some_and(is_less_than!(2)) +#[macro_export] +macro_rules! is_less_than { + ($val:expr) => { + |x| (x < $val) + }; +} + +/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z')) +#[macro_export] +macro_rules! is_matching { + ($val:expr) => { + |x| matches!(x, $val) + }; } /// Returns false if the exponential backoff has expired based on the inputs @@ -100,3 +150,6 @@ fn try_into_err, Src>(e: >::Error) -> Erro type_name::() )) } + +#[inline] +pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 1556646ec..c34691d2d 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,73 +1,41 @@ +pub mod bool; pub mod bytes; pub mod content_disposition; pub mod debug; pub mod defer; +pub mod future; pub mod hash; pub mod html; pub mod json; pub mod math; pub mod mutex_map; pub mod rand; +pub mod set; +pub mod stream; pub mod string; pub mod sys; mod tests; pub mod time; -use std::cmp::{self, Ordering}; - +pub use ::conduit_macros::implement; pub use ::ctor::{ctor, dtor}; -pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; -pub use conduit_macros::implement; -pub use debug::slice_truncated as debug_slice_truncated; -pub use hash::calculate_hash; -pub use html::Escape as HtmlEscape; -pub use json::{deserialize_from_str, to_canonical_object}; -pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; -pub use rand::string as random_string; -pub use string::{str_from_bytes, string_from_bytes}; -pub use sys::available_parallelism; -pub use time::now_millis as millis_since_unix_epoch; -#[inline] -pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } +pub use self::{ + bool::BoolExt, + bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}, + debug::slice_truncated as debug_slice_truncated, + future::TryExtExt as TryFutureExtExt, + hash::calculate_hash, + html::Escape as HtmlEscape, + json::{deserialize_from_str, to_canonical_object}, + math::clamp, + mutex_map::{Guard as MutexMapGuard, MutexMap}, + rand::string as random_string, + stream::{IterStream, ReadyExt, Tools as StreamTools, TryReadyExt}, + string::{str_from_bytes, string_from_bytes}, + sys::available_parallelism, + time::now_millis as millis_since_unix_epoch, +}; #[inline] -pub fn exchange(state: &mut T, source: T) -> T { - let ret = state.clone(); - *state = source; - ret -} - -#[must_use] -pub fn generate_keypair() -> Vec { - let mut value = rand::string(8).as_bytes().to_vec(); - value.push(0xFF); - value.extend_from_slice( - &ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"), - ); - value -} - -#[allow(clippy::impl_trait_in_params)] -pub fn common_elements( - mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, -) -> Option>> { - let first_iterator = iterators.next()?; - let mut other_iterators = iterators.map(Iterator::peekable).collect::>(); - - Some(first_iterator.filter(move |target| { - other_iterators.iter_mut().all(|it| { - while let Some(element) = it.peek() { - match check_order(element, target) { - Ordering::Greater => return false, // We went too far - Ordering::Equal => return true, // Element is in both iters - Ordering::Less => { - // Keep searching - it.next(); - }, - } - } - false - }) - })) -} +pub fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } diff --git a/src/core/utils/set.rs b/src/core/utils/set.rs new file mode 100644 index 000000000..563f9df5c --- /dev/null +++ b/src/core/utils/set.rs @@ -0,0 +1,47 @@ +use std::cmp::{Eq, Ord}; + +use crate::{is_equal_to, is_less_than}; + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs do not have to +/// be sorted. If inputs are sorted a more optimized function is available in +/// this suite and should be used. +pub fn intersection(mut input: Iters) -> impl Iterator + Send +where + Iters: Iterator + Clone + Send, + Iter: Iterator + Send, + Item: Eq + Send, +{ + input.next().into_iter().flat_map(move |first| { + let input = input.clone(); + first.filter(move |targ| { + input + .clone() + .all(|mut other| other.any(is_equal_to!(*targ))) + }) + }) +} + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs must be sorted. +pub fn intersection_sorted(mut input: Iters) -> impl Iterator + Send +where + Iters: Iterator + Clone + Send, + Iter: Iterator + Send, + Item: Eq + Ord + Send, +{ + input.next().into_iter().flat_map(move |first| { + let mut input = input.clone().collect::>(); + first.filter(move |targ| { + input.iter_mut().all(|it| { + it.by_ref() + .skip_while(is_less_than!(targ)) + .peekable() + .peek() + .is_some_and(is_equal_to!(targ)) + }) + }) + }) +} diff --git a/src/core/utils/stream/cloned.rs b/src/core/utils/stream/cloned.rs new file mode 100644 index 000000000..d6a0e6470 --- /dev/null +++ b/src/core/utils/stream/cloned.rs @@ -0,0 +1,20 @@ +use std::clone::Clone; + +use futures::{stream::Map, Stream, StreamExt}; + +pub trait Cloned<'a, T, S> +where + S: Stream, + T: Clone + 'a, +{ + fn cloned(self) -> Map T>; +} + +impl<'a, T, S> Cloned<'a, T, S> for S +where + S: Stream, + T: Clone + 'a, +{ + #[inline] + fn cloned(self) -> Map T> { self.map(Clone::clone) } +} diff --git a/src/core/utils/stream/expect.rs b/src/core/utils/stream/expect.rs new file mode 100644 index 000000000..3ab7181a8 --- /dev/null +++ b/src/core/utils/stream/expect.rs @@ -0,0 +1,17 @@ +use futures::{Stream, StreamExt, TryStream}; + +use crate::Result; + +pub trait TryExpect<'a, Item> { + fn expect_ok(self) -> impl Stream + Send + 'a; +} + +impl<'a, T, Item> TryExpect<'a, Item> for T +where + T: Stream> + TryStream + Send + 'a, +{ + #[inline] + fn expect_ok(self: T) -> impl Stream + Send + 'a { + self.map(|res| res.expect("stream expectation failure")) + } +} diff --git a/src/core/utils/stream/ignore.rs b/src/core/utils/stream/ignore.rs new file mode 100644 index 000000000..997aa4ba4 --- /dev/null +++ b/src/core/utils/stream/ignore.rs @@ -0,0 +1,21 @@ +use futures::{future::ready, Stream, StreamExt, TryStream}; + +use crate::{Error, Result}; + +pub trait TryIgnore<'a, Item> { + fn ignore_err(self) -> impl Stream + Send + 'a; + + fn ignore_ok(self) -> impl Stream + Send + 'a; +} + +impl<'a, T, Item> TryIgnore<'a, Item> for T +where + T: Stream> + TryStream + Send + 'a, + Item: Send + 'a, +{ + #[inline] + fn ignore_err(self: T) -> impl Stream + Send + 'a { self.filter_map(|res| ready(res.ok())) } + + #[inline] + fn ignore_ok(self: T) -> impl Stream + Send + 'a { self.filter_map(|res| ready(res.err())) } +} diff --git a/src/core/utils/stream/iter_stream.rs b/src/core/utils/stream/iter_stream.rs new file mode 100644 index 000000000..69edf64f5 --- /dev/null +++ b/src/core/utils/stream/iter_stream.rs @@ -0,0 +1,27 @@ +use futures::{ + stream, + stream::{Stream, TryStream}, + StreamExt, +}; + +pub trait IterStream { + /// Convert an Iterator into a Stream + fn stream(self) -> impl Stream::Item> + Send; + + /// Convert an Iterator into a TryStream + fn try_stream(self) -> impl TryStream::Item, Error = crate::Error> + Send; +} + +impl IterStream for I +where + I: IntoIterator + Send, + ::IntoIter: Send, +{ + #[inline] + fn stream(self) -> impl Stream::Item> + Send { stream::iter(self) } + + #[inline] + fn try_stream(self) -> impl TryStream::Item, Error = crate::Error> + Send { + self.stream().map(Ok) + } +} diff --git a/src/core/utils/stream/mod.rs b/src/core/utils/stream/mod.rs new file mode 100644 index 000000000..1111915b3 --- /dev/null +++ b/src/core/utils/stream/mod.rs @@ -0,0 +1,15 @@ +mod cloned; +mod expect; +mod ignore; +mod iter_stream; +mod ready; +mod tools; +mod try_ready; + +pub use cloned::Cloned; +pub use expect::TryExpect; +pub use ignore::TryIgnore; +pub use iter_stream::IterStream; +pub use ready::ReadyExt; +pub use tools::Tools; +pub use try_ready::TryReadyExt; diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs new file mode 100644 index 000000000..da5aec5a6 --- /dev/null +++ b/src/core/utils/stream/ready.rs @@ -0,0 +1,141 @@ +//! Synchronous combinator extensions to futures::Stream + +use futures::{ + future::{ready, Ready}, + stream::{Any, Filter, FilterMap, Fold, ForEach, Scan, SkipWhile, Stream, StreamExt, TakeWhile}, +}; + +/// Synchronous combinators to augment futures::StreamExt. Most Stream +/// combinators take asynchronous arguments, but often only simple predicates +/// are required to steer a Stream like an Iterator. This suite provides a +/// convenience to reduce boilerplate by de-cluttering non-async predicates. +/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait ReadyExt +where + Self: Stream + Send + Sized, +{ + fn ready_any(self, f: F) -> Any, impl FnMut(Item) -> Ready> + where + F: Fn(Item) -> bool; + + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a; + + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(Item) -> Ready>> + where + F: Fn(Item) -> Option; + + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, Item) -> Ready> + where + F: Fn(T, Item) -> T; + + fn ready_for_each(self, f: F) -> ForEach, impl FnMut(Item) -> Ready<()>> + where + F: FnMut(Item); + + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a; + + fn ready_scan( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, Item) -> Option; + + fn ready_scan_each( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, &Item); + + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a; +} + +impl ReadyExt for S +where + S: Stream + Send + Sized, +{ + #[inline] + fn ready_any(self, f: F) -> Any, impl FnMut(Item) -> Ready> + where + F: Fn(Item) -> bool, + { + self.any(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter<'a, F>(self, f: F) -> Filter, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a, + { + self.filter(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter_map(self, f: F) -> FilterMap>, impl FnMut(Item) -> Ready>> + where + F: Fn(Item) -> Option, + { + self.filter_map(move |t| ready(f(t))) + } + + #[inline] + fn ready_fold(self, init: T, f: F) -> Fold, T, impl FnMut(T, Item) -> Ready> + where + F: Fn(T, Item) -> T, + { + self.fold(init, move |a, t| ready(f(a, t))) + } + + #[inline] + #[allow(clippy::unit_arg)] + fn ready_for_each(self, mut f: F) -> ForEach, impl FnMut(Item) -> Ready<()>> + where + F: FnMut(Item), + { + self.for_each(move |t| ready(f(t))) + } + + #[inline] + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a, + { + self.take_while(move |t| ready(f(t))) + } + + #[inline] + fn ready_scan( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, Item) -> Option, + { + self.scan(init, move |s, t| ready(f(s, t))) + } + + fn ready_scan_each( + self, init: T, f: F, + ) -> Scan>, impl FnMut(&mut T, Item) -> Ready>> + where + F: Fn(&mut T, &Item), + { + self.ready_scan(init, move |s, t| { + f(s, &t); + Some(t) + }) + } + + #[inline] + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile, impl FnMut(&Item) -> Ready + 'a> + where + F: Fn(&Item) -> bool + 'a, + { + self.skip_while(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/stream/tools.rs b/src/core/utils/stream/tools.rs new file mode 100644 index 000000000..cc6b7ca9e --- /dev/null +++ b/src/core/utils/stream/tools.rs @@ -0,0 +1,80 @@ +//! StreamTools for futures::Stream + +use std::{collections::HashMap, hash::Hash}; + +use futures::{Future, Stream, StreamExt}; + +use super::ReadyExt; +use crate::expected; + +/// StreamTools +/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait Tools +where + Self: Stream + Send + Sized, + ::Item: Send, +{ + fn counts(self) -> impl Future> + Send + where + ::Item: Eq + Hash; + + fn counts_by(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send; + + fn counts_by_with_cap(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send; + + fn counts_with_cap(self) -> impl Future> + Send + where + ::Item: Eq + Hash; +} + +impl Tools for S +where + S: Stream + Send + Sized, + ::Item: Send, +{ + #[inline] + fn counts(self) -> impl Future> + Send + where + ::Item: Eq + Hash, + { + self.counts_with_cap::<0>() + } + + #[inline] + fn counts_by(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send, + { + self.counts_by_with_cap::<0, K, F>(f) + } + + #[inline] + fn counts_by_with_cap(self, f: F) -> impl Future> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send, + { + self.map(f).counts_with_cap::() + } + + #[inline] + fn counts_with_cap(self) -> impl Future> + Send + where + ::Item: Eq + Hash, + { + self.ready_fold(HashMap::with_capacity(CAP), |mut counts, item| { + let entry = counts.entry(item).or_default(); + let value = *entry; + *entry = expected!(value + 1); + counts + }) + } +} diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs new file mode 100644 index 000000000..ab37d9b30 --- /dev/null +++ b/src/core/utils/stream/try_ready.rs @@ -0,0 +1,35 @@ +//! Synchronous combinator extensions to futures::TryStream + +use futures::{ + future::{ready, Ready}, + stream::{AndThen, TryStream, TryStreamExt}, +}; + +use crate::Result; + +/// Synchronous combinators to augment futures::TryStreamExt. +/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait TryReadyExt +where + S: TryStream> + Send + ?Sized, + Self: TryStream + Send + Sized, +{ + fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result; +} + +impl TryReadyExt for S +where + S: TryStream> + Send + ?Sized, + Self: TryStream + Send + Sized, +{ + #[inline] + fn ready_and_then(self, f: F) -> AndThen>, impl FnMut(S::Ok) -> Ready>> + where + F: Fn(S::Ok) -> Result, + { + self.and_then(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs index 85282b30a..e65a33698 100644 --- a/src/core/utils/string.rs +++ b/src/core/utils/string.rs @@ -1,3 +1,10 @@ +mod between; +mod split; +mod tests; +mod unquote; +mod unquoted; + +pub use self::{between::Between, split::SplitInfallible, unquote::Unquote, unquoted::Unquoted}; use crate::{utils::exchange, Result}; pub const EMPTY: &str = ""; @@ -95,12 +102,6 @@ pub fn common_prefix<'a>(choice: &'a [&str]) -> &'a str { }) } -#[inline] -#[must_use] -pub fn split_once_infallible<'a>(input: &'a str, delim: &'_ str) -> (&'a str, &'a str) { - input.split_once(delim).unwrap_or((input, EMPTY)) -} - /// Parses the bytes into a string. pub fn string_from_bytes(bytes: &[u8]) -> Result { let str: &str = str_from_bytes(bytes)?; diff --git a/src/core/utils/string/between.rs b/src/core/utils/string/between.rs new file mode 100644 index 000000000..209a9dabb --- /dev/null +++ b/src/core/utils/string/between.rs @@ -0,0 +1,26 @@ +type Delim<'a> = (&'a str, &'a str); + +/// Slice a string between a pair of delimeters. +pub trait Between<'a> { + /// Extract a string between the delimeters. If the delimeters were not + /// found None is returned, otherwise the first extraction is returned. + fn between(&self, delim: Delim<'_>) -> Option<&'a str>; + + /// Extract a string between the delimeters. If the delimeters were not + /// found the original string is returned; take note of this behavior, + /// if an empty slice is desired for this case use the fallible version and + /// unwrap to EMPTY. + fn between_infallible(&self, delim: Delim<'_>) -> &'a str; +} + +impl<'a> Between<'a> for &'a str { + #[inline] + fn between_infallible(&self, delim: Delim<'_>) -> &'a str { self.between(delim).unwrap_or(self) } + + #[inline] + fn between(&self, delim: Delim<'_>) -> Option<&'a str> { + self.split_once(delim.0) + .and_then(|(_, b)| b.rsplit_once(delim.1)) + .map(|(a, _)| a) + } +} diff --git a/src/core/utils/string/split.rs b/src/core/utils/string/split.rs new file mode 100644 index 000000000..96de28dff --- /dev/null +++ b/src/core/utils/string/split.rs @@ -0,0 +1,22 @@ +use super::EMPTY; + +type Pair<'a> = (&'a str, &'a str); + +/// Split a string with default behaviors on non-match. +pub trait SplitInfallible<'a> { + /// Split a string at the first occurrence of delim. If not found, the + /// entire string is returned in \[0\], while \[1\] is empty. + fn split_once_infallible(&self, delim: &str) -> Pair<'a>; + + /// Split a string from the last occurrence of delim. If not found, the + /// entire string is returned in \[0\], while \[1\] is empty. + fn rsplit_once_infallible(&self, delim: &str) -> Pair<'a>; +} + +impl<'a> SplitInfallible<'a> for &'a str { + #[inline] + fn rsplit_once_infallible(&self, delim: &str) -> Pair<'a> { self.rsplit_once(delim).unwrap_or((self, EMPTY)) } + + #[inline] + fn split_once_infallible(&self, delim: &str) -> Pair<'a> { self.split_once(delim).unwrap_or((self, EMPTY)) } +} diff --git a/src/core/utils/string/tests.rs b/src/core/utils/string/tests.rs new file mode 100644 index 000000000..e8c17de6d --- /dev/null +++ b/src/core/utils/string/tests.rs @@ -0,0 +1,70 @@ +#![cfg(test)] + +#[test] +fn common_prefix() { + let input = ["conduwuit", "conduit", "construct"]; + let output = super::common_prefix(&input); + assert_eq!(output, "con"); +} + +#[test] +fn common_prefix_empty() { + let input = ["abcdefg", "hijklmn", "opqrstu"]; + let output = super::common_prefix(&input); + assert_eq!(output, ""); +} + +#[test] +fn common_prefix_none() { + let input = []; + let output = super::common_prefix(&input); + assert_eq!(output, ""); +} + +#[test] +fn camel_to_snake_case_0() { + let res = super::camel_to_snake_string("CamelToSnakeCase"); + assert_eq!(res, "camel_to_snake_case"); +} + +#[test] +fn camel_to_snake_case_1() { + let res = super::camel_to_snake_string("CAmelTOSnakeCase"); + assert_eq!(res, "camel_tosnake_case"); +} + +#[test] +fn unquote() { + use super::Unquote; + + assert_eq!("\"foo\"".unquote(), Some("foo")); + assert_eq!("\"foo".unquote(), None); + assert_eq!("foo".unquote(), None); +} + +#[test] +fn unquote_infallible() { + use super::Unquote; + + assert_eq!("\"foo\"".unquote_infallible(), "foo"); + assert_eq!("\"foo".unquote_infallible(), "\"foo"); + assert_eq!("foo".unquote_infallible(), "foo"); +} + +#[test] +fn between() { + use super::Between; + + assert_eq!("\"foo\"".between(("\"", "\"")), Some("foo")); + assert_eq!("\"foo".between(("\"", "\"")), None); + assert_eq!("foo".between(("\"", "\"")), None); +} + +#[test] +fn between_infallible() { + use super::Between; + + assert_eq!("\"foo\"".between_infallible(("\"", "\"")), "foo"); + assert_eq!("\"foo".between_infallible(("\"", "\"")), "\"foo"); + assert_eq!("foo".between_infallible(("\"", "\"")), "foo"); +} diff --git a/src/core/utils/string/unquote.rs b/src/core/utils/string/unquote.rs new file mode 100644 index 000000000..eeded610a --- /dev/null +++ b/src/core/utils/string/unquote.rs @@ -0,0 +1,33 @@ +const QUOTE: char = '"'; + +/// Slice a string between quotes +pub trait Unquote<'a> { + /// Whether the input is quoted. If this is false the fallible methods of + /// this interface will fail. + fn is_quoted(&self) -> bool; + + /// Unquotes a string. If the input is not quoted it is simply returned + /// as-is. If the input is partially quoted on either end that quote is not + /// removed. + fn unquote(&self) -> Option<&'a str>; + + /// Unquotes a string. The input must be quoted on each side for Some to be + /// returned + fn unquote_infallible(&self) -> &'a str; +} + +impl<'a> Unquote<'a> for &'a str { + #[inline] + fn unquote_infallible(&self) -> &'a str { + self.strip_prefix(QUOTE) + .unwrap_or(self) + .strip_suffix(QUOTE) + .unwrap_or(self) + } + + #[inline] + fn unquote(&self) -> Option<&'a str> { self.strip_prefix(QUOTE).and_then(|s| s.strip_suffix(QUOTE)) } + + #[inline] + fn is_quoted(&self) -> bool { self.starts_with(QUOTE) && self.ends_with(QUOTE) } +} diff --git a/src/core/utils/string/unquoted.rs b/src/core/utils/string/unquoted.rs new file mode 100644 index 000000000..5b002d99b --- /dev/null +++ b/src/core/utils/string/unquoted.rs @@ -0,0 +1,52 @@ +use std::ops::Deref; + +use serde::{de, Deserialize, Deserializer}; + +use super::Unquote; +use crate::{err, Result}; + +/// Unquoted string which deserialized from a quoted string. Construction from a +/// &str is infallible such that the input can already be unquoted. Construction +/// from serde deserialization is fallible and the input must be quoted. +#[repr(transparent)] +pub struct Unquoted(str); + +impl<'a> Unquoted { + #[inline] + #[must_use] + pub fn as_str(&'a self) -> &'a str { &self.0 } +} + +impl<'a, 'de: 'a> Deserialize<'de> for &'a Unquoted { + fn deserialize>(deserializer: D) -> Result { + let s = <&'a str>::deserialize(deserializer)?; + s.is_quoted() + .then_some(s) + .ok_or(err!(SerdeDe("expected quoted string"))) + .map_err(de::Error::custom) + .map(Into::into) + } +} + +impl<'a> From<&'a str> for &'a Unquoted { + fn from(s: &'a str) -> &'a Unquoted { + let s: &'a str = s.unquote_infallible(); + + //SAFETY: This is a pattern I lifted from ruma-identifiers for strong-type strs + // by wrapping in a tuple-struct. + #[allow(clippy::transmute_ptr_to_ptr)] + unsafe { + std::mem::transmute(s) + } + } +} + +impl Deref for Unquoted { + type Target = str; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl<'a> AsRef for &'a Unquoted { + fn as_ref(&self) -> &'a str { &self.0 } +} diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index e91accdf4..84d35936e 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -36,33 +36,6 @@ fn increment_wrap() { assert_eq!(res, 0); } -#[test] -fn common_prefix() { - use utils::string; - - let input = ["conduwuit", "conduit", "construct"]; - let output = string::common_prefix(&input); - assert_eq!(output, "con"); -} - -#[test] -fn common_prefix_empty() { - use utils::string; - - let input = ["abcdefg", "hijklmn", "opqrstu"]; - let output = string::common_prefix(&input); - assert_eq!(output, ""); -} - -#[test] -fn common_prefix_none() { - use utils::string; - - let input = []; - let output = string::common_prefix(&input); - assert_eq!(output, ""); -} - #[test] fn checked_add() { use crate::checked; @@ -136,17 +109,131 @@ async fn mutex_map_contend() { } #[test] -fn camel_to_snake_case_0() { - use utils::string::camel_to_snake_string; +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_none() { + use utils::set::intersection; + + let a: [&str; 0] = []; + let b: [&str; 0] = []; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a: [&str; 0] = []; + let b = ["abc", "def"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a = ["foo", "bar", "baz"]; + let b = ["def", "hij", "klm", "nop"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); +} - let res = camel_to_snake_string("CamelToSnakeCase"); - assert_eq!(res, "camel_to_snake_case"); +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_all() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["foo", "bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar"].iter())); + let i = [b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["foo", "bar", "baz"]; + let b = ["baz", "foo", "bar"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar", "baz"].iter())); } #[test] -fn camel_to_snake_case_1() { - use utils::string::camel_to_snake_string; +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_some() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["abcdef", "foo", "hijkl", "abc"]; + let b = ["hij", "bar", "baz", "abc", "foo"]; + let c = ["abc", "xyz", "foo", "ghi"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "abc"].iter())); +} - let res = camel_to_snake_string("CAmelTOSnakeCase"); - assert_eq!(res, "camel_tosnake_case"); +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_some() { + use utils::set::intersection_sorted; + + let a = ["bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + + let a = ["aaa", "ccc", "eee", "ggg"]; + let b = ["aaa", "bbb", "ccc", "ddd", "eee"]; + let c = ["bbb", "ccc", "eee", "fff"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["ccc", "eee"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_all() { + use utils::set::intersection_sorted; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["bar", "foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + let i = [b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["bar", "baz", "foo"]; + let b = ["bar", "baz", "foo"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "baz", "foo"].iter())); } diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 34d98416d..0e718aa71 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -35,10 +35,14 @@ zstd_compression = [ ] [dependencies] +arrayvec.workspace = true conduit-core.workspace = true const-str.workspace = true +futures.workspace = true log.workspace = true rust-rocksdb.workspace = true +serde.workspace = true +serde_json.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/database/database.rs b/src/database/database.rs index c357d50f2..4c29c840c 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -37,7 +37,15 @@ impl Database { pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } #[inline] - pub fn iter_maps(&self) -> impl Iterator + '_ { self.map.iter() } + pub fn iter_maps(&self) -> impl Iterator + Send + '_ { self.map.iter() } + + #[inline] + #[must_use] + pub fn is_read_only(&self) -> bool { self.db.secondary || self.db.read_only } + + #[inline] + #[must_use] + pub fn is_secondary(&self) -> bool { self.db.secondary } } impl Index<&str> for Database { diff --git a/src/database/de.rs b/src/database/de.rs new file mode 100644 index 000000000..fc36560d6 --- /dev/null +++ b/src/database/de.rs @@ -0,0 +1,311 @@ +use conduit::{checked, debug::DebugInspect, err, utils::string, Error, Result}; +use serde::{ + de, + de::{DeserializeSeed, Visitor}, + Deserialize, +}; + +pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result +where + T: Deserialize<'a>, +{ + let mut deserializer = Deserializer { + buf, + pos: 0, + }; + + T::deserialize(&mut deserializer).debug_inspect(|_| { + deserializer + .finished() + .expect("deserialization failed to consume trailing bytes"); + }) +} + +pub(crate) struct Deserializer<'de> { + buf: &'de [u8], + pos: usize, +} + +/// Directive to ignore a record. This type can be used to skip deserialization +/// until the next separator is found. +#[derive(Debug, Deserialize)] +pub struct Ignore; + +impl<'de> Deserializer<'de> { + const SEP: u8 = b'\xFF'; + + fn finished(&self) -> Result<()> { + let pos = self.pos; + let len = self.buf.len(); + let parsed = &self.buf[0..pos]; + let unparsed = &self.buf[pos..]; + let remain = checked!(len - pos)?; + let trailing_sep = remain == 1 && unparsed[0] == Self::SEP; + (remain == 0 || trailing_sep) + .then_some(()) + .ok_or(err!(SerdeDe( + "{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}", + ))) + } + + #[inline] + fn record_next(&mut self) -> &'de [u8] { + self.buf[self.pos..] + .split(|b| *b == Deserializer::SEP) + .inspect(|record| self.inc_pos(record.len())) + .next() + .expect("remainder of buf even if SEP was not found") + } + + #[inline] + fn record_next_peek_byte(&self) -> Option { + let started = self.pos != 0; + let buf = &self.buf[self.pos..]; + debug_assert!( + !started || buf[0] == Self::SEP, + "Missing expected record separator at current position" + ); + + buf.get::(started.into()).copied() + } + + #[inline] + fn record_start(&mut self) { + let started = self.pos != 0; + debug_assert!( + !started || self.buf[self.pos] == Self::SEP, + "Missing expected record separator at current position" + ); + + self.inc_pos(started.into()); + } + + #[inline] + fn record_trail(&mut self) -> &'de [u8] { + let record = &self.buf[self.pos..]; + self.inc_pos(record.len()); + record + } + + #[inline] + fn inc_pos(&mut self, n: usize) { + self.pos = self.pos.saturating_add(n); + debug_assert!(self.pos <= self.buf.len(), "pos out of range"); + } +} + +impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(self) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_map(visitor).map_err(Into::into) + } + + fn deserialize_struct(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'de>, + { + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_struct(name, fields, visitor) + .map_err(Into::into) + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + match name { + "Ignore" => self.record_next(), + _ => unimplemented!("Unrecognized deserialization Directive {name:?}"), + }; + + visitor.visit_unit() + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de>, + { + match name { + "$serde_json::private::RawValue" => visitor.visit_map(self), + _ => visitor.visit_newtype_struct(self), + } + } + + fn deserialize_enum( + self, _name: &'static str, _variants: &'static [&'static str], _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + unimplemented!("deserialize Enum not implemented") + } + + fn deserialize_option>(self, _visitor: V) -> Result { + unimplemented!("deserialize Option not implemented") + } + + fn deserialize_bool>(self, _visitor: V) -> Result { + unimplemented!("deserialize bool not implemented") + } + + fn deserialize_i8>(self, _visitor: V) -> Result { + unimplemented!("deserialize i8 not implemented") + } + + fn deserialize_i16>(self, _visitor: V) -> Result { + unimplemented!("deserialize i16 not implemented") + } + + fn deserialize_i32>(self, _visitor: V) -> Result { + unimplemented!("deserialize i32 not implemented") + } + + fn deserialize_i64>(self, visitor: V) -> Result { + let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; + self.pos = self.pos.saturating_add(size_of::()); + visitor.visit_i64(i64::from_be_bytes(bytes)) + } + + fn deserialize_u8>(self, _visitor: V) -> Result { + unimplemented!("deserialize u8 not implemented; try dereferencing the Handle for [u8] access instead") + } + + fn deserialize_u16>(self, _visitor: V) -> Result { + unimplemented!("deserialize u16 not implemented") + } + + fn deserialize_u32>(self, _visitor: V) -> Result { + unimplemented!("deserialize u32 not implemented") + } + + fn deserialize_u64>(self, visitor: V) -> Result { + let bytes: [u8; size_of::()] = self.buf[self.pos..].try_into()?; + self.pos = self.pos.saturating_add(size_of::()); + visitor.visit_u64(u64::from_be_bytes(bytes)) + } + + fn deserialize_f32>(self, _visitor: V) -> Result { + unimplemented!("deserialize f32 not implemented") + } + + fn deserialize_f64>(self, _visitor: V) -> Result { + unimplemented!("deserialize f64 not implemented") + } + + fn deserialize_char>(self, _visitor: V) -> Result { + unimplemented!("deserialize char not implemented") + } + + fn deserialize_str>(self, visitor: V) -> Result { + let input = self.record_next(); + let out = string::str_from_bytes(input)?; + visitor.visit_borrowed_str(out) + } + + fn deserialize_string>(self, visitor: V) -> Result { + let input = self.record_next(); + let out = string::string_from_bytes(input)?; + visitor.visit_string(out) + } + + fn deserialize_bytes>(self, visitor: V) -> Result { + let input = self.record_trail(); + visitor.visit_borrowed_bytes(input) + } + + fn deserialize_byte_buf>(self, _visitor: V) -> Result { + unimplemented!("deserialize Byte Buf not implemented") + } + + fn deserialize_unit>(self, _visitor: V) -> Result { + unimplemented!("deserialize Unit not implemented") + } + + // this only used for $serde_json::private::RawValue at this time; see MapAccess + fn deserialize_identifier>(self, visitor: V) -> Result { + let input = "$serde_json::private::RawValue"; + visitor.visit_borrowed_str(input) + } + + fn deserialize_ignored_any>(self, _visitor: V) -> Result { + unimplemented!("deserialize Ignored Any not implemented") + } + + fn deserialize_any>(self, visitor: V) -> Result { + debug_assert_eq!( + conduit::debug::type_name::(), + "serde_json::value::de::::deserialize::ValueVisitor", + "deserialize_any: type not expected" + ); + + match self.record_next_peek_byte() { + Some(b'{') => self.deserialize_map(visitor), + _ => self.deserialize_str(visitor), + } + } +} + +impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + if self.pos >= self.buf.len() { + return Ok(None); + } + + self.record_start(); + seed.deserialize(&mut **self).map(Some) + } +} + +// this only used for $serde_json::private::RawValue at this time. our db +// schema doesn't have its own map format; we use json for that anyway +impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + seed.deserialize(&mut **self).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(&mut **self) + } +} diff --git a/src/database/deserialized.rs b/src/database/deserialized.rs new file mode 100644 index 000000000..a59b2ce54 --- /dev/null +++ b/src/database/deserialized.rs @@ -0,0 +1,20 @@ +use std::convert::identity; + +use conduit::Result; +use serde::Deserialize; + +pub trait Deserialized { + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>; + + #[inline] + fn deserialized(self) -> Result + where + T: for<'de> Deserialize<'de>, + Self: Sized, + { + self.map_de(identity::) + } +} diff --git a/src/database/engine.rs b/src/database/engine.rs index 3850c1d3f..99d971ed6 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -10,7 +10,7 @@ use conduit::{debug, error, info, utils::time::rfc2822_from_seconds, warn, Err, use rocksdb::{ backup::{BackupEngine, BackupEngineOptions}, perf::get_memory_usage_stats, - AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode, Env, + AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode, Env, LogLevel, MultiThreaded, Options, }; @@ -28,6 +28,8 @@ pub struct Engine { cfs: Mutex>, pub(crate) db: Db, corks: AtomicU32, + pub(super) read_only: bool, + pub(super) secondary: bool, } pub(crate) type Db = DBWithThreadMode; @@ -80,10 +82,13 @@ impl Engine { .collect::>(); debug!("Opening database..."); + let path = &config.database_path; let res = if config.rocksdb_read_only { - Db::open_cf_for_read_only(&db_opts, &config.database_path, cfs.clone(), false) + Db::open_cf_descriptors_read_only(&db_opts, path, cfds, false) + } else if config.rocksdb_secondary { + Db::open_cf_descriptors_as_secondary(&db_opts, path, path, cfds) } else { - Db::open_cf_descriptors(&db_opts, &config.database_path, cfds) + Db::open_cf_descriptors(&db_opts, path, cfds) }; let db = res.or_else(or_else)?; @@ -103,10 +108,12 @@ impl Engine { cfs: Mutex::new(cfs), db, corks: AtomicU32::new(0), + read_only: config.rocksdb_read_only, + secondary: config.rocksdb_secondary, })) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "trace")] pub(crate) fn open_cf(&self, name: &str) -> Result>> { let mut cfs = self.cfs.lock().expect("locked"); if !cfs.contains(name) { @@ -279,6 +286,21 @@ pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> { Ok(()) } +#[tracing::instrument(skip_all, name = "rocksdb")] +pub(crate) fn handle_log(level: LogLevel, msg: &str) { + let msg = msg.trim(); + if msg.starts_with("Options") { + return; + } + + match level { + LogLevel::Header | LogLevel::Debug => debug!("{msg}"), + LogLevel::Error | LogLevel::Fatal => error!("{msg}"), + LogLevel::Info => debug!("{msg}"), + LogLevel::Warn => warn!("{msg}"), + }; +} + impl Drop for Engine { #[cold] fn drop(&mut self) { diff --git a/src/database/handle.rs b/src/database/handle.rs index 0b45a75f0..daee224d4 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -1,6 +1,10 @@ -use std::ops::Deref; +use std::{fmt, fmt::Debug, ops::Deref}; +use conduit::Result; use rocksdb::DBPinnableSlice; +use serde::{Deserialize, Serialize, Serializer}; + +use crate::{keyval::deserialize_val, Deserialized, Slice}; pub struct Handle<'a> { val: DBPinnableSlice<'a>, @@ -14,14 +18,67 @@ impl<'a> From> for Handle<'a> { } } +impl Debug for Handle<'_> { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { + let val: &Slice = self; + let ptr = val.as_ptr(); + let len = val.len(); + write!(out, "Handle {{val: {{ptr: {ptr:?}, len: {len}}}}}") + } +} + +impl Serialize for Handle<'_> { + #[inline] + fn serialize(&self, serializer: S) -> Result { + let bytes: &Slice = self; + serializer.serialize_bytes(bytes) + } +} + +impl Deserialized for Result> { + #[inline] + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self?.map_de(f) + } +} + +impl<'a> Deserialized for Result<&'a Handle<'a>> { + #[inline] + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self.and_then(|handle| handle.map_de(f)) + } +} + +impl<'a> Deserialized for &'a Handle<'a> { + fn map_de(self, f: F) -> Result + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + deserialize_val(self.as_ref()).map(f) + } +} + +impl From> for Vec { + fn from(handle: Handle<'_>) -> Self { handle.deref().to_vec() } +} + impl Deref for Handle<'_> { - type Target = [u8]; + type Target = Slice; #[inline] fn deref(&self) -> &Self::Target { &self.val } } -impl AsRef<[u8]> for Handle<'_> { +impl AsRef for Handle<'_> { #[inline] - fn as_ref(&self) -> &[u8] { &self.val } + fn as_ref(&self) -> &Slice { &self.val } } diff --git a/src/database/iter.rs b/src/database/iter.rs deleted file mode 100644 index 4845e9773..000000000 --- a/src/database/iter.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::{iter::FusedIterator, sync::Arc}; - -use conduit::Result; -use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, Direction, IteratorMode, ReadOptions}; - -use crate::{ - engine::Db, - result, - slice::{OwnedKeyVal, OwnedKeyValPair}, - Engine, -}; - -type Cursor<'cursor> = DBRawIteratorWithThreadMode<'cursor, Db>; - -struct State<'cursor> { - cursor: Cursor<'cursor>, - direction: Direction, - valid: bool, - init: bool, -} - -impl<'cursor> State<'cursor> { - pub(crate) fn new( - db: &'cursor Arc, cf: &'cursor Arc, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - let mut cursor = db.db.raw_iterator_cf_opt(&**cf, opts); - let direction = into_direction(mode); - let valid = seek_init(&mut cursor, mode); - Self { - cursor, - direction, - valid, - init: true, - } - } -} - -pub struct Iter<'cursor> { - state: State<'cursor>, -} - -impl<'cursor> Iter<'cursor> { - pub(crate) fn new( - db: &'cursor Arc, cf: &'cursor Arc, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - Self { - state: State::new(db, cf, opts, mode), - } - } -} - -impl Iterator for Iter<'_> { - type Item = OwnedKeyValPair; - - fn next(&mut self) -> Option { - if !self.state.init && self.state.valid { - seek_next(&mut self.state.cursor, self.state.direction); - } else if self.state.init { - self.state.init = false; - } - - self.state - .cursor - .item() - .map(OwnedKeyVal::from) - .map(OwnedKeyVal::to_tuple) - .or_else(|| { - when_invalid(&mut self.state).expect("iterator invalidated due to error"); - None - }) - } -} - -impl FusedIterator for Iter<'_> {} - -fn when_invalid(state: &mut State<'_>) -> Result<()> { - state.valid = false; - result(state.cursor.status()) -} - -fn seek_next(cursor: &mut Cursor<'_>, direction: Direction) { - match direction { - Direction::Forward => cursor.next(), - Direction::Reverse => cursor.prev(), - } -} - -fn seek_init(cursor: &mut Cursor<'_>, mode: &IteratorMode<'_>) -> bool { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start => cursor.seek_to_first(), - End => cursor.seek_to_last(), - From(key, Forward) => cursor.seek(key), - From(key, Reverse) => cursor.seek_for_prev(key), - }; - - cursor.valid() -} - -fn into_direction(mode: &IteratorMode<'_>) -> Direction { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start | From(_, Forward) => Forward, - End | From(_, Reverse) => Reverse, - } -} diff --git a/src/database/keyval.rs b/src/database/keyval.rs new file mode 100644 index 000000000..a288f1842 --- /dev/null +++ b/src/database/keyval.rs @@ -0,0 +1,75 @@ +use conduit::Result; +use serde::Deserialize; + +use crate::de; + +pub type KeyVal<'a, K = &'a Slice, V = &'a Slice> = (Key<'a, K>, Val<'a, V>); +pub type Key<'a, T = &'a Slice> = T; +pub type Val<'a, T = &'a Slice> = T; + +pub type Slice = [u8]; + +#[inline] +pub(crate) fn _expect_deserialize<'a, K, V>(kv: Result>) -> KeyVal<'a, K, V> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + result_deserialize(kv).expect("failed to deserialize result key/val") +} + +#[inline] +pub(crate) fn _expect_deserialize_key<'a, K>(key: Result>) -> Key<'a, K> +where + K: Deserialize<'a>, +{ + result_deserialize_key(key).expect("failed to deserialize result key") +} + +#[inline] +pub(crate) fn result_deserialize<'a, K, V>(kv: Result>) -> Result> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + deserialize(kv?) +} + +#[inline] +pub(crate) fn result_deserialize_key<'a, K>(key: Result>) -> Result> +where + K: Deserialize<'a>, +{ + deserialize_key(key?) +} + +#[inline] +pub(crate) fn deserialize<'a, K, V>(kv: KeyVal<'a>) -> Result> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + Ok((deserialize_key::(kv.0)?, deserialize_val::(kv.1)?)) +} + +#[inline] +pub(crate) fn deserialize_key<'a, K>(key: Key<'a>) -> Result> +where + K: Deserialize<'a>, +{ + de::from_slice::(key) +} + +#[inline] +pub(crate) fn deserialize_val<'a, V>(val: Val<'a>) -> Result> +where + V: Deserialize<'a>, +{ + de::from_slice::(val) +} + +#[inline] +pub fn key(kv: KeyVal<'_, K, V>) -> Key<'_, K> { kv.0 } + +#[inline] +pub fn val(kv: KeyVal<'_, K, V>) -> Val<'_, V> { kv.1 } diff --git a/src/database/map.rs b/src/database/map.rs index ddae8c813..cac20d6a6 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -1,16 +1,34 @@ -use std::{ffi::CStr, future::Future, mem::size_of, pin::Pin, sync::Arc}; - -use conduit::{utils, Result}; -use rocksdb::{ - AsColumnFamilyRef, ColumnFamily, Direction, IteratorMode, ReadOptions, WriteBatchWithTransaction, WriteOptions, +mod count; +mod get; +mod insert; +mod keys; +mod keys_from; +mod keys_prefix; +mod remove; +mod rev_keys; +mod rev_keys_from; +mod rev_keys_prefix; +mod rev_stream; +mod rev_stream_from; +mod rev_stream_prefix; +mod stream; +mod stream_from; +mod stream_prefix; + +use std::{ + convert::AsRef, + ffi::CStr, + fmt, + fmt::{Debug, Display}, + future::Future, + pin::Pin, + sync::Arc, }; -use crate::{ - or_else, result, - slice::{Byte, Key, KeyVal, OwnedKey, OwnedKeyValPair, OwnedVal, Val}, - watchers::Watchers, - Engine, Handle, Iter, -}; +use conduit::Result; +use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteOptions}; + +use crate::{watchers::Watchers, Engine}; pub struct Map { name: String, @@ -21,8 +39,6 @@ pub struct Map { read_options: ReadOptions, } -type OwnedKeyValPairIter<'a> = Box + Send + 'a>; - impl Map { pub(crate) fn open(db: &Arc, name: &str) -> Result> { Ok(Arc::new(Self { @@ -35,162 +51,18 @@ impl Map { })) } - pub fn get(&self, key: &Key) -> Result>> { - let read_options = &self.read_options; - let res = self.db.db.get_pinned_cf_opt(&self.cf(), key, read_options); - - Ok(result(res)?.map(Handle::from)) - } - - pub fn multi_get(&self, keys: &[&Key]) -> Result>> { - // Optimization can be `true` if key vector is pre-sorted **by the column - // comparator**. - const SORTED: bool = false; - - let mut ret: Vec> = Vec::with_capacity(keys.len()); - let read_options = &self.read_options; - for res in self - .db - .db - .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) - { - match res { - Ok(Some(res)) => ret.push(Some((*res).to_vec())), - Ok(None) => ret.push(None), - Err(e) => return or_else(e), - } - } - - Ok(ret) - } - - pub fn insert(&self, key: &Key, value: &Val) -> Result<()> { - let write_options = &self.write_options; - self.db - .db - .put_cf_opt(&self.cf(), key, value, write_options) - .or_else(or_else)?; - - if !self.db.corked() { - self.db.flush()?; - } - - self.watchers.wake(key); - - Ok(()) - } - - pub fn insert_batch<'a, I>(&'a self, iter: I) -> Result<()> - where - I: Iterator>, - { - let mut batch = WriteBatchWithTransaction::::default(); - for KeyVal(key, value) in iter { - batch.put_cf(&self.cf(), key, value); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove(&self, key: &Key) -> Result<()> { - let write_options = &self.write_options; - let res = self.db.db.delete_cf_opt(&self.cf(), key, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove_batch<'a, I>(&'a self, iter: I) -> Result<()> - where - I: Iterator, - { - let mut batch = WriteBatchWithTransaction::::default(); - for key in iter { - batch.delete_cf(&self.cf(), key); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn iter(&self) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::Start; - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub fn iter_from(&self, from: &Key, reverse: bool) -> OwnedKeyValPairIter<'_> { - let direction = if reverse { - Direction::Reverse - } else { - Direction::Forward - }; - let mode = IteratorMode::From(from, direction); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub fn scan_prefix(&self, prefix: OwnedKey) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::From(&prefix, Direction::Forward); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode).take_while(move |(k, _)| k.starts_with(&prefix))) - } - - pub fn increment(&self, key: &Key) -> Result<[Byte; size_of::()]> { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - self.insert(key, &new)?; - - if !self.db.corked() { - self.db.flush()?; - } - - Ok(new) - } - - pub fn increment_batch<'a, I>(&'a self, iter: I) -> Result<()> + #[inline] + pub fn watch_prefix<'a, K>(&'a self, prefix: &K) -> Pin + Send + 'a>> where - I: Iterator, + K: AsRef<[u8]> + ?Sized + Debug, { - let mut batch = WriteBatchWithTransaction::::default(); - for key in iter { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - batch.put_cf(&self.cf(), key, new); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn watch_prefix<'a>(&'a self, prefix: &Key) -> Pin + Send + 'a>> { - self.watchers.watch(prefix) + self.watchers.watch(prefix.as_ref()) } + #[inline] pub fn property_integer(&self, name: &CStr) -> Result { self.db.property_integer(&self.cf(), name) } + #[inline] pub fn property(&self, name: &str) -> Result { self.db.property(&self.cf(), name) } #[inline] @@ -199,12 +71,12 @@ impl Map { fn cf(&self) -> impl AsColumnFamilyRef + '_ { &*self.cf } } -impl<'a> IntoIterator for &'a Map { - type IntoIter = Box + Send + 'a>; - type Item = OwnedKeyValPair; +impl Debug for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "Map {{name: {0}}}", self.name) } +} - #[inline] - fn into_iter(self) -> Self::IntoIter { self.iter() } +impl Display for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "{0}", self.name) } } fn open(db: &Arc, name: &str) -> Result> { @@ -212,10 +84,7 @@ fn open(db: &Arc, name: &str) -> Result> { let bounded_ptr = Arc::into_raw(bounded_arc); let cf_ptr = bounded_ptr.cast::(); - // SAFETY: After thorough contemplation this appears to be the best solution, - // even by a significant margin. - // - // BACKGROUND: Column family handles out of RocksDB are basic pointers and can + // SAFETY: Column family handles out of RocksDB are basic pointers and can // be invalidated: 1. when the database closes. 2. when the column is dropped or // closed. rust_rocksdb wraps this for us by storing handles in their own // `RwLock` map and returning an Arc>` to diff --git a/src/database/map/count.rs b/src/database/map/count.rs new file mode 100644 index 000000000..4356b71f5 --- /dev/null +++ b/src/database/map/count.rs @@ -0,0 +1,36 @@ +use std::{fmt::Debug, future::Future}; + +use conduit::implement; +use futures::stream::StreamExt; +use serde::Serialize; + +use crate::de::Ignore; + +/// Count the total number of entries in the map. +#[implement(super::Map)] +#[inline] +pub fn count(&self) -> impl Future + Send + '_ { self.keys::().count() } + +/// Count the number of entries in the map starting from a lower-bound. +/// +/// - From is a structured key +#[implement(super::Map)] +#[inline] +pub fn count_from<'a, P>(&'a self, from: &P) -> impl Future + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_from::(from).count() +} + +/// Count the number of entries in the map matching a prefix. +/// +/// - Prefix is structured key +#[implement(super::Map)] +#[inline] +pub fn count_prefix<'a, P>(&'a self, prefix: &P) -> impl Future + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_prefix::(prefix).count() +} diff --git a/src/database/map/get.rs b/src/database/map/get.rs new file mode 100644 index 000000000..72382e367 --- /dev/null +++ b/src/database/map/get.rs @@ -0,0 +1,102 @@ +use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; + +use arrayvec::ArrayVec; +use conduit::{err, implement, Result}; +use futures::future::ready; +use rocksdb::DBPinnableSlice; +use serde::Serialize; + +use crate::{ser, util, Handle}; + +type RocksdbResult<'a> = Result>, rocksdb::Error>; + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into an allocated buffer to perform +/// the query. +#[implement(super::Map)] +pub fn qry(&self, key: &K) -> impl Future>> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::::with_capacity(64); + self.bqry(key, &mut buf) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into a fixed-sized buffer to perform +/// the query. The maximum size is supplied as const generic parameter. +#[implement(super::Map)] +pub fn aqry(&self, key: &K) -> impl Future>> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::::new(); + self.bqry(key, &mut buf) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into a user-supplied Writer. +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bqry(&self, key: &K, buf: &mut B) -> impl Future>> + Send +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + self.get(key) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is referenced directly to perform the query. +#[implement(super::Map)] +pub fn get(&self, key: &K) -> impl Future>> + Send +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + ready(self.get_blocking(key)) +} + +/// Fetch a value from the database into cache, returning a reference-handle. +/// The key is referenced directly to perform the query. This is a thread- +/// blocking call. +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn get_blocking(&self, key: &K) -> Result> +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + let res = self + .db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options); + + into_result_handle(res) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec>> +where + I: Iterator + ExactSizeIterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, +{ + // Optimization can be `true` if key vector is pre-sorted **by the column + // comparator**. + const SORTED: bool = false; + + let read_options = &self.read_options; + self.db + .db + .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) + .into_iter() + .map(into_result_handle) + .collect() +} + +fn into_result_handle(result: RocksdbResult<'_>) -> Result> { + result + .map_err(util::map_err)? + .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) +} diff --git a/src/database/map/insert.rs b/src/database/map/insert.rs new file mode 100644 index 000000000..953c9c94c --- /dev/null +++ b/src/database/map/insert.rs @@ -0,0 +1,52 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::implement; +use rocksdb::WriteBatchWithTransaction; + +use crate::util::or_else; + +#[implement(super::Map)] +#[tracing::instrument(skip(self, value), fields(%self), level = "trace")] +pub fn insert(&self, key: &K, value: &V) +where + K: AsRef<[u8]> + ?Sized + Debug, + V: AsRef<[u8]> + ?Sized, +{ + let write_options = &self.write_options; + self.db + .db + .put_cf_opt(&self.cf(), key, value, write_options) + .or_else(or_else) + .expect("database insert error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } + + self.watchers.wake(key.as_ref()); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, iter), fields(%self), level = "trace")] +pub fn insert_batch<'a, I, K, V>(&'a self, iter: I) +where + I: Iterator + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, + V: AsRef<[u8]> + Sized + 'a, +{ + let mut batch = WriteBatchWithTransaction::::default(); + for (key, val) in iter { + batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); + } + + let write_options = &self.write_options; + self.db + .db + .write_opt(batch, write_options) + .or_else(or_else) + .expect("database insert batch error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } +} diff --git a/src/database/map/keys.rs b/src/database/map/keys.rs new file mode 100644 index 000000000..2396494c4 --- /dev/null +++ b/src/database/map/keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys<'a, K>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, +{ + self.raw_keys().map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/keys_from.rs b/src/database/map/keys_from.rs new file mode 100644 index 000000000..1993750ab --- /dev/null +++ b/src/database/map/keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_raw_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.raw_keys_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/keys_prefix.rs b/src/database/map/keys_prefix.rs new file mode 100644 index 000000000..d6c0927b9 --- /dev/null +++ b/src/database/map/keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_raw_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/remove.rs b/src/database/map/remove.rs new file mode 100644 index 000000000..10bb2ff01 --- /dev/null +++ b/src/database/map/remove.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug, io::Write}; + +use arrayvec::ArrayVec; +use conduit::implement; +use serde::Serialize; + +use crate::{ser, util::or_else}; + +#[implement(super::Map)] +pub fn del(&self, key: &K) +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::::with_capacity(64); + self.bdel(key, &mut buf); +} + +#[implement(super::Map)] +pub fn adel(&self, key: &K) +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::::new(); + self.bdel(key, &mut buf); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bdel(&self, key: &K, buf: &mut B) +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); + self.remove(key); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn remove(&self, key: &K) +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + let write_options = &self.write_options; + self.db + .db + .delete_cf_opt(&self.cf(), key, write_options) + .or_else(or_else) + .expect("database remove error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } +} diff --git a/src/database/map/rev_keys.rs b/src/database/map/rev_keys.rs new file mode 100644 index 000000000..449ccfff3 --- /dev/null +++ b/src/database/map/rev_keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys<'a, K>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys().map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_keys_from.rs b/src/database/map/rev_keys_from.rs new file mode 100644 index 000000000..e012e60af --- /dev/null +++ b/src/database/map/rev_keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_raw_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from_raw<'a, K, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys_from(from) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_keys_prefix.rs b/src/database/map/rev_keys_prefix.rs new file mode 100644 index 000000000..162c4f9b8 --- /dev/null +++ b/src/database/map/rev_keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_raw_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix_raw<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/rev_stream.rs b/src/database/map/rev_stream.rs new file mode 100644 index 000000000..de22fd5ce --- /dev/null +++ b/src/database/map/rev_stream.rs @@ -0,0 +1,29 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the end. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream<'a, K, V>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream() + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map from the end. +/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_stream_from.rs b/src/database/map/rev_stream_from.rs new file mode 100644 index 000000000..650cf038c --- /dev/null +++ b/src/database/map/rev_stream_from.rs @@ -0,0 +1,68 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_stream_raw_from(&key) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream_from(from) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_stream_prefix.rs b/src/database/map/rev_stream_prefix.rs new file mode 100644 index 000000000..9ef89e9cb --- /dev/null +++ b/src/database/map/rev_stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_stream_raw_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix_raw<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_stream_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/stream.rs b/src/database/map/stream.rs new file mode 100644 index 000000000..dfbea0729 --- /dev/null +++ b/src/database/map/stream.rs @@ -0,0 +1,28 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the beginning. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream<'a, K, V>(&'a self) -> impl Stream>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream().map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map from the beginning. +/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream(&self) -> impl Stream>> + Send { + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/stream_from.rs b/src/database/map/stream_from.rs new file mode 100644 index 000000000..153d5bb61 --- /dev/null +++ b/src/database/map/stream_from.rs @@ -0,0 +1,68 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.stream_raw_from(&key) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from_raw<'a, K, V, P>(&'a self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream_from(from) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_from

    (&self, from: &P) -> impl Stream>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/stream_prefix.rs b/src/database/map/stream_prefix.rs new file mode 100644 index 000000000..56154a8b3 --- /dev/null +++ b/src/database/map/stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.stream_raw_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_prefix

    (&self, prefix: &P) -> impl Stream>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix_raw<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.raw_stream_prefix(prefix) + .map(keyval::result_deserialize::) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 6446624ca..e66abf682 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,25 +1,35 @@ mod cork; mod database; +mod de; +mod deserialized; mod engine; mod handle; -mod iter; +pub mod keyval; mod map; pub mod maps; mod opts; -mod slice; +mod ser; +mod stream; mod util; mod watchers; +pub(crate) use self::{ + engine::Engine, + util::{or_else, result}, +}; + extern crate conduit_core as conduit; extern crate rust_rocksdb as rocksdb; -pub use database::Database; -pub(crate) use engine::Engine; -pub use handle::Handle; -pub use iter::Iter; -pub use map::Map; -pub use slice::{Key, KeyVal, OwnedKey, OwnedKeyVal, OwnedVal, Val}; -pub(crate) use util::{or_else, result}; +pub use self::{ + database::Database, + de::Ignore, + deserialized::Deserialized, + handle::Handle, + keyval::{KeyVal, Slice}, + map::Map, + ser::{Interfix, Separator}, +}; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/database/opts.rs b/src/database/opts.rs index d2ad4b95c..46fb4c542 100644 --- a/src/database/opts.rs +++ b/src/database/opts.rs @@ -191,6 +191,8 @@ fn set_logging_defaults(opts: &mut Options, config: &Config) { if config.rocksdb_log_stderr { opts.set_stderr_logger(rocksdb_log_level, "rocksdb"); + } else { + opts.set_callback_logger(rocksdb_log_level, &super::engine::handle_log); } } diff --git a/src/database/ser.rs b/src/database/ser.rs new file mode 100644 index 000000000..bd4bbd9ad --- /dev/null +++ b/src/database/ser.rs @@ -0,0 +1,315 @@ +use std::io::Write; + +use conduit::{err, result::DebugInspect, utils::exchange, Error, Result}; +use serde::{ser, Serialize}; + +#[inline] +pub(crate) fn serialize_to_vec(val: &T) -> Result> +where + T: Serialize + ?Sized, +{ + let mut buf = Vec::with_capacity(64); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub(crate) fn serialize<'a, W, T>(out: &'a mut W, val: &'a T) -> Result<&'a [u8]> +where + W: Write + AsRef<[u8]>, + T: Serialize + ?Sized, +{ + let mut serializer = Serializer { + out, + depth: 0, + sep: false, + fin: false, + }; + + val.serialize(&mut serializer) + .map_err(|error| err!(SerdeSer("{error}"))) + .debug_inspect(|()| { + debug_assert_eq!(serializer.depth, 0, "Serialization completed at non-zero recursion level"); + })?; + + Ok((*out).as_ref()) +} + +pub(crate) struct Serializer<'a, W: Write> { + out: &'a mut W, + depth: u32, + sep: bool, + fin: bool, +} + +/// Directive to force separator serialization specifically for prefix keying +/// use. This is a quirk of the database schema and prefix iterations. +#[derive(Debug, Serialize)] +pub struct Interfix; + +/// Directive to force separator serialization. Separators are usually +/// serialized automatically. +#[derive(Debug, Serialize)] +pub struct Separator; + +impl Serializer<'_, W> { + const SEP: &'static [u8] = b"\xFF"; + + fn sequence_start(&mut self) { + debug_assert!(!self.is_finalized(), "Sequence start with finalization set"); + debug_assert!(!self.sep, "Sequence start with separator set"); + if cfg!(debug_assertions) { + self.depth = self.depth.saturating_add(1); + } + } + + fn sequence_end(&mut self) { + self.sep = false; + if cfg!(debug_assertions) { + self.depth = self.depth.saturating_sub(1); + } + } + + fn record_start(&mut self) -> Result<()> { + debug_assert!(!self.is_finalized(), "Starting a record after serialization finalized"); + exchange(&mut self.sep, true) + .then(|| self.separator()) + .unwrap_or(Ok(())) + } + + fn separator(&mut self) -> Result<()> { + debug_assert!(!self.is_finalized(), "Writing a separator after serialization finalized"); + self.out.write_all(Self::SEP).map_err(Into::into) + } + + fn set_finalized(&mut self) { + debug_assert!(!self.is_finalized(), "Finalization already set"); + if cfg!(debug_assertions) { + self.fin = true; + } + } + + fn is_finalized(&self) -> bool { self.fin } +} + +impl ser::Serializer for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = Self; + type SerializeSeq = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + + fn serialize_map(self, _len: Option) -> Result { + unimplemented!("serialize Map not implemented") + } + + fn serialize_seq(self, _len: Option) -> Result { + self.sequence_start(); + self.record_start()?; + Ok(self) + } + + fn serialize_tuple(self, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_struct_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result { + self.sequence_start(); + Ok(self) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result { + unimplemented!("serialize New Type Struct not implemented") + } + + fn serialize_newtype_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, + ) -> Result { + unimplemented!("serialize New Type Variant not implemented") + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + match name { + "Interfix" => { + self.set_finalized(); + }, + "Separator" => { + self.separator()?; + }, + _ => unimplemented!("Unrecognized serialization directive: {name:?}"), + }; + + Ok(()) + } + + fn serialize_unit_variant(self, _name: &'static str, _idx: u32, _var: &'static str) -> Result { + unimplemented!("serialize Unit Variant not implemented") + } + + fn serialize_some(self, val: &T) -> Result { val.serialize(self) } + + fn serialize_none(self) -> Result { Ok(()) } + + fn serialize_char(self, v: char) -> Result { + let mut buf: [u8; 4] = [0; 4]; + self.serialize_str(v.encode_utf8(&mut buf)) + } + + fn serialize_str(self, v: &str) -> Result { self.serialize_bytes(v.as_bytes()) } + + fn serialize_bytes(self, v: &[u8]) -> Result { self.out.write_all(v).map_err(Error::Io) } + + fn serialize_f64(self, _v: f64) -> Result { unimplemented!("serialize f64 not implemented") } + + fn serialize_f32(self, _v: f32) -> Result { unimplemented!("serialize f32 not implemented") } + + fn serialize_i64(self, v: i64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + + fn serialize_i32(self, _v: i32) -> Result { unimplemented!("serialize i32 not implemented") } + + fn serialize_i16(self, _v: i16) -> Result { unimplemented!("serialize i16 not implemented") } + + fn serialize_i8(self, _v: i8) -> Result { unimplemented!("serialize i8 not implemented") } + + fn serialize_u64(self, v: u64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + + fn serialize_u32(self, _v: u32) -> Result { unimplemented!("serialize u32 not implemented") } + + fn serialize_u16(self, _v: u16) -> Result { unimplemented!("serialize u16 not implemented") } + + fn serialize_u8(self, v: u8) -> Result { self.out.write_all(&[v]).map_err(Error::Io) } + + fn serialize_bool(self, _v: bool) -> Result { unimplemented!("serialize bool not implemented") } + + fn serialize_unit(self) -> Result { unimplemented!("serialize unit not implemented") } +} + +impl ser::SerializeMap for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_key(&mut self, _key: &T) -> Result { + unimplemented!("serialize Map Key not implemented") + } + + fn serialize_value(&mut self, _val: &T) -> Result { + unimplemented!("serialize Map Val not implemented") + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeSeq for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeStructVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTuple for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} + +impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { + self.sequence_end(); + Ok(()) + } +} diff --git a/src/database/slice.rs b/src/database/slice.rs deleted file mode 100644 index 448d969d9..000000000 --- a/src/database/slice.rs +++ /dev/null @@ -1,57 +0,0 @@ -pub struct OwnedKeyVal(pub OwnedKey, pub OwnedVal); -pub(crate) type OwnedKeyValPair = (OwnedKey, OwnedVal); -pub type OwnedVal = Vec; -pub type OwnedKey = Vec; - -pub struct KeyVal<'item>(pub &'item Key, pub &'item Val); -pub(crate) type KeyValPair<'item> = (&'item Key, &'item Val); -pub type Val = [Byte]; -pub type Key = [Byte]; - -pub(crate) type Byte = u8; - -impl OwnedKeyVal { - #[must_use] - pub fn as_slice(&self) -> KeyVal<'_> { KeyVal(&self.0, &self.1) } - - #[must_use] - pub fn to_tuple(self) -> OwnedKeyValPair { (self.0, self.1) } -} - -impl From for OwnedKeyVal { - fn from((key, val): OwnedKeyValPair) -> Self { Self(key, val) } -} - -impl From<&KeyVal<'_>> for OwnedKeyVal { - #[inline] - fn from(slice: &KeyVal<'_>) -> Self { slice.to_owned() } -} - -impl From> for OwnedKeyVal { - fn from((key, val): KeyValPair<'_>) -> Self { Self(Vec::from(key), Vec::from(val)) } -} - -impl From for OwnedKeyValPair { - fn from(val: OwnedKeyVal) -> Self { val.to_tuple() } -} - -impl KeyVal<'_> { - #[inline] - #[must_use] - pub fn to_owned(&self) -> OwnedKeyVal { OwnedKeyVal::from(self) } - - #[must_use] - pub fn as_tuple(&self) -> KeyValPair<'_> { (self.0, self.1) } -} - -impl<'a> From<&'a OwnedKeyVal> for KeyVal<'a> { - fn from(owned: &'a OwnedKeyVal) -> Self { owned.as_slice() } -} - -impl<'a> From<&'a OwnedKeyValPair> for KeyVal<'a> { - fn from((key, val): &'a OwnedKeyValPair) -> Self { KeyVal(key.as_slice(), val.as_slice()) } -} - -impl<'a> From> for KeyVal<'a> { - fn from((key, val): KeyValPair<'a>) -> Self { KeyVal(key, val) } -} diff --git a/src/database/stream.rs b/src/database/stream.rs new file mode 100644 index 000000000..d9b74215d --- /dev/null +++ b/src/database/stream.rs @@ -0,0 +1,122 @@ +mod items; +mod items_rev; +mod keys; +mod keys_rev; + +use std::sync::Arc; + +use conduit::{utils::exchange, Error, Result}; +use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, ReadOptions}; + +pub(crate) use self::{items::Items, items_rev::ItemsRev, keys::Keys, keys_rev::KeysRev}; +use crate::{ + engine::Db, + keyval::{Key, KeyVal, Val}, + util::map_err, + Engine, Slice, +}; + +struct State<'a> { + inner: Inner<'a>, + seek: bool, + init: bool, +} + +trait Cursor<'a, T> { + fn state(&self) -> &State<'a>; + + fn fetch(&self) -> Option; + + fn seek(&mut self); + + fn get(&self) -> Option> { + self.fetch() + .map(Ok) + .or_else(|| self.state().status().map(Err)) + } + + fn seek_and_get(&mut self) -> Option> { + self.seek(); + self.get() + } +} + +type Inner<'a> = DBRawIteratorWithThreadMode<'a, Db>; +type From<'a> = Option>; + +impl<'a> State<'a> { + fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions) -> Self { + Self { + inner: db.db.raw_iterator_cf_opt(&**cf, opts), + init: true, + seek: false, + } + } + + fn init_fwd(mut self, from: From<'_>) -> Self { + if let Some(key) = from { + self.inner.seek(key); + self.seek = true; + } + + self + } + + fn init_rev(mut self, from: From<'_>) -> Self { + if let Some(key) = from { + self.inner.seek_for_prev(key); + self.seek = true; + } + + self + } + + fn seek_fwd(&mut self) { + if !exchange(&mut self.init, false) { + self.inner.next(); + } else if !self.seek { + self.inner.seek_to_first(); + } + } + + fn seek_rev(&mut self) { + if !exchange(&mut self.init, false) { + self.inner.prev(); + } else if !self.seek { + self.inner.seek_to_last(); + } + } + + fn fetch_key(&self) -> Option> { self.inner.key().map(Key::from) } + + fn _fetch_val(&self) -> Option> { self.inner.value().map(Val::from) } + + fn fetch(&self) -> Option> { self.inner.item().map(KeyVal::from) } + + fn status(&self) -> Option { self.inner.status().map_err(map_err).err() } + + fn valid(&self) -> bool { self.inner.valid() } +} + +fn keyval_longevity<'a, 'b: 'a>(item: KeyVal<'a>) -> KeyVal<'b> { + (slice_longevity::<'a, 'b>(item.0), slice_longevity::<'a, 'b>(item.1)) +} + +fn slice_longevity<'a, 'b: 'a>(item: &'a Slice) -> &'b Slice { + // SAFETY: The lifetime of the data returned by the rocksdb cursor is only valid + // between each movement of the cursor. It is hereby unsafely extended to match + // the lifetime of the cursor itself. This is due to the limitation of the + // Stream trait where the Item is incapable of conveying a lifetime; this is due + // to GAT's being unstable during its development. This unsafety can be removed + // as soon as this limitation is addressed by an upcoming version. + // + // We have done our best to mitigate the implications of this in conjunction + // with the deserialization API such that borrows being held across movements of + // the cursor do not happen accidentally. The compiler will still error when + // values herein produced try to leave a closure passed to a StreamExt API. But + // escapes can happen if you explicitly and intentionally attempt it, and there + // will be no compiler error or warning. This is primarily the case with + // calling collect() without a preceding map(ToOwned::to_owned). A collection + // of references here is illegal, but this will not be enforced by the compiler. + unsafe { std::mem::transmute(item) } +} diff --git a/src/database/stream/items.rs b/src/database/stream/items.rs new file mode 100644 index 000000000..31d5e9e8d --- /dev/null +++ b/src/database/stream/items.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct Items<'a> { + state: State<'a>, +} + +impl<'a> Items<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for Items<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Items<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Items<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/items_rev.rs b/src/database/stream/items_rev.rs new file mode 100644 index 000000000..ab57a2506 --- /dev/null +++ b/src/database/stream/items_rev.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct ItemsRev<'a> { + state: State<'a>, +} + +impl<'a> ItemsRev<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for ItemsRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch().map(keyval_longevity) } + + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for ItemsRev<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for ItemsRev<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys.rs b/src/database/stream/keys.rs new file mode 100644 index 000000000..1c5d12e30 --- /dev/null +++ b/src/database/stream/keys.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct Keys<'a> { + state: State<'a>, +} + +impl<'a> Keys<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for Keys<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Keys<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Keys<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys_rev.rs b/src/database/stream/keys_rev.rs new file mode 100644 index 000000000..267074837 --- /dev/null +++ b/src/database/stream/keys_rev.rs @@ -0,0 +1,44 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct KeysRev<'a> { + state: State<'a>, +} + +impl<'a> KeysRev<'a> { + pub(crate) fn new(db: &'a Arc, cf: &'a Arc, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for KeysRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option> { self.state.fetch_key().map(slice_longevity) } + + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for KeysRev<'a> { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for KeysRev<'_> { + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/util.rs b/src/database/util.rs index f0ccbcbee..d36e183f4 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,4 +1,16 @@ use conduit::{err, Result}; +use rocksdb::{Direction, IteratorMode}; + +#[inline] +pub(crate) fn _into_direction(mode: &IteratorMode<'_>) -> Direction { + use Direction::{Forward, Reverse}; + use IteratorMode::{End, From, Start}; + + match mode { + Start | From(_, Forward) => Forward, + End | From(_, Reverse) => Reverse, + } +} #[inline] pub(crate) fn result(r: std::result::Result) -> Result { diff --git a/src/macros/utils.rs b/src/macros/utils.rs index 58074e3a0..197dd90e9 100644 --- a/src/macros/utils.rs +++ b/src/macros/utils.rs @@ -41,8 +41,5 @@ pub(crate) fn camel_to_snake_string(s: &str) -> String { output } -pub(crate) fn exchange(state: &mut T, source: T) -> T { - let ret = state.clone(); - *state = source; - ret -} +#[inline] +pub(crate) fn exchange(state: &mut T, source: T) -> T { std::mem::replace(state, source) } diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index fb011f188..5df41b614 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -10,7 +10,7 @@ use axum::{ extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, Router, }; -use conduit::{debug, debug_error, error::infallible, info, trace, warn, Err, Result, Server}; +use conduit::{debug, debug_error, info, result::UnwrapInfallible, trace, warn, Err, Result, Server}; use hyper::{body::Incoming, service::service_fn}; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -62,11 +62,7 @@ async fn accept( let socket = TokioIo::new(socket); trace!(?listener, ?socket, ?remote, "accepted"); - let called = app - .call(NULL_ADDR) - .await - .inspect_err(infallible) - .expect("infallible"); + let called = app.call(NULL_ADDR).await.unwrap_infallible(); let service = move |req: Request| called.clone().oneshot(req); let handler = service_fn(service); diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index cfed5a0e3..737a70399 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -46,7 +46,7 @@ bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hickory-resolver.workspace = true http.workspace = true image.workspace = true diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs deleted file mode 100644 index 53a0e9533..000000000 --- a/src/service/account_data/data.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use conduit::{Error, Result}; -use database::Map; -use ruma::{ - api::client::error::ErrorKind, - events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - roomuserdataid_accountdata: Arc, - roomusertype_roomuserdataid: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), - roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - /// Places one event in the account data of the user and removes the - /// previous entry. - pub(super) fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); - - let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); - - if data.get("type").is_none() || data.get("content").is_none() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Account data doesn't have all required fields.", - )); - } - - self.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - )?; - - let prev = self.roomusertype_roomuserdataid.get(&key)?; - - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - - // Remove old entry - if let Some(prev) = prev { - self.roomuserdataid_accountdata.remove(&prev)?; - } - - Ok(()) - } - - /// Searches the account data for a specific kind. - pub(super) fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType, - ) -> Result>> { - let mut key = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(kind.to_string().as_bytes()); - - self.roomusertype_roomuserdataid - .get(&key)? - .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() - }) - .transpose()? - .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) - .transpose() - } - - /// Returns all changes to the account data that happened after `since`. - pub(super) fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result> { - let mut userdata = HashMap::new(); - - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - // Skip the data that's exactly at since, because we sent that last time - let mut first_possible = prefix.clone(); - first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); - - for r in self - .roomuserdataid_accountdata - .iter_from(&first_possible, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(k, v)| { - Ok::<_, Error>(( - k, - match room_id { - None => serde_json::from_slice::>(&v) - .map(AnyRawAccountDataEvent::Global) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - Some(_) => serde_json::from_slice::>(&v) - .map(AnyRawAccountDataEvent::Room) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - }, - )) - }) { - let (kind, data) = r?; - userdata.insert(kind, data); - } - - Ok(userdata.into_values().collect()) - } -} diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index eaa536417..482229e7f 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -1,52 +1,158 @@ -mod data; +use std::{collections::HashMap, sync::Arc}; -use std::sync::Arc; - -use conduit::Result; -use data::Data; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Map}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ - events::{AnyRawAccountDataEvent, RoomAccountDataEventType}, + events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType}, + serde::Raw, RoomId, UserId, }; +use serde_json::value::RawValue; + +use crate::{globals, Dep}; pub struct Service { + services: Services, db: Data, } +struct Data { + roomuserdataid_accountdata: Arc, + roomusertype_roomuserdataid: Arc, +} + +struct Services { + globals: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + services: Services { + globals: args.depend::("globals"), + }, + db: Data { + roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(), + roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Places one event in the account data of the user and removes the - /// previous entry. - #[allow(clippy::needless_pass_by_value)] - pub fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - self.db.update(room_id, user_id, &event_type, data) - } +/// Places one event in the account data of the user and removes the +/// previous entry. +#[allow(clippy::needless_pass_by_value)] +#[implement(Service)] +pub async fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, +) -> Result<()> { + let event_type = event_type.to_string(); + let count = self.services.globals.next_count()?; + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); - /// Searches the account data for a specific kind. - #[allow(clippy::needless_pass_by_value)] - pub fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - ) -> Result>> { - self.db.get(room_id, user_id, &event_type) + let mut roomuserdataid = prefix.clone(); + roomuserdataid.extend_from_slice(&count.to_be_bytes()); + roomuserdataid.push(0xFF); + roomuserdataid.extend_from_slice(event_type.as_bytes()); + + let mut key = prefix; + key.extend_from_slice(event_type.as_bytes()); + + if data.get("type").is_none() || data.get("content").is_none() { + return Err!(Request(InvalidParam("Account data doesn't have all required fields."))); } - /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip_all, name = "since", level = "debug")] - pub fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result> { - self.db.changes_since(room_id, user_id, since) + self.db.roomuserdataid_accountdata.insert( + &roomuserdataid, + &serde_json::to_vec(&data).expect("to_vec always works on json values"), + ); + + let prev_key = (room_id, user_id, &event_type); + let prev = self.db.roomusertype_roomuserdataid.qry(&prev_key).await; + + self.db + .roomusertype_roomuserdataid + .insert(&key, &roomuserdataid); + + // Remove old entry + if let Ok(prev) = prev { + self.db.roomuserdataid_accountdata.remove(&prev); } + + Ok(()) +} + +/// Searches the account data for a specific kind. +#[implement(Service)] +pub async fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, +) -> Result> { + let key = (room_id, user_id, kind.to_string()); + self.db + .roomusertype_roomuserdataid + .qry(&key) + .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.get(&roomuserdataid)) + .await + .deserialized() +} + +/// Returns all changes to the account data that happened after `since`. +#[implement(Service)] +pub async fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, +) -> Result> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); + + self.db + .roomuserdataid_accountdata + .raw_stream_from(&first_possible) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + let v = match room_id { + None => serde_json::from_slice::>(v) + .map(AnyRawAccountDataEvent::Global) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + Some(_) => serde_json::from_slice::>(v) + .map(AnyRawAccountDataEvent::Room) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + }; + + Ok((k.to_owned(), v)) + }) + .ignore_err() + .ready_for_each(|(kind, data)| { + userdata.insert(kind, data); + }) + .await; + + Ok(userdata.into_values().collect()) } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 55bae3658..0f5016e15 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -5,7 +5,7 @@ use std::{ }; use conduit::{debug, defer, error, log, Server}; -use futures_util::future::{AbortHandle, Abortable}; +use futures::future::{AbortHandle, Abortable}; use ruma::events::room::message::RoomMessageEventContent; use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; use termimad::MadSkin; diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index 4e2b831c5..3dd5aea35 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -30,7 +30,11 @@ use crate::Services; pub async fn create_admin_room(services: &Services) -> Result<()> { let room_id = RoomId::new(services.globals.server_name()); - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index b4589ebc8..4b3ebb887 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -17,108 +17,108 @@ use serde_json::value::to_raw_value; use crate::pdu::PduBuilder; -impl super::Service { - /// Invite the user to the conduit admin room. - /// - /// In conduit, this is equivalent to granting admin privileges. - pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { - let Some(room_id) = self.get_admin_room()? else { - return Ok(()); - }; +/// Invite the user to the conduit admin room. +/// +/// In conduit, this is equivalent to granting admin privileges. +#[implement(super::Service)] +pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { + let Ok(room_id) = self.get_admin_room().await else { + return Ok(()); + }; - let state_lock = self.services.state.mutex.lock(&room_id).await; + let state_lock = self.services.state.mutex.lock(&room_id).await; - // Use the server user to grant the new admin's power level - let server_user = &self.services.globals.server_user; + // Use the server user to grant the new admin's power level + let server_user = &self.services.globals.server_user; - // Invite and join the real user - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ) - .await?; - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - user_id, - &room_id, - &state_lock, - ) - .await?; + // Invite and join the real user + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + server_user, + &room_id, + &state_lock, + ) + .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + timestamp: None, + }, + user_id, + &room_id, + &state_lock, + ) + .await?; - // Set power level - let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); + // Set power level + let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ) - .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(String::new()), + redacts: None, + timestamp: None, + }, + server_user, + &room_id, + &state_lock, + ) + .await?; - // Set room tag - let room_tag = &self.services.server.config.admin_room_tag; - if !room_tag.is_empty() { - if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag) { - error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); - } + // Set room tag + let room_tag = &self.services.server.config.admin_room_tag; + if !room_tag.is_empty() { + if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag).await { + error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); } + } - // Send welcome message - self.services.timeline.build_and_append_pdu( + // Send welcome message + self.services.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_markdown( @@ -135,19 +135,18 @@ impl super::Service { &state_lock, ).await?; - Ok(()) - } + Ok(()) } #[implement(super::Service)] -fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { +async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { let mut event = self .services .account_data - .get(Some(room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| serde_json::from_str(event.get())) - .and_then(Result::ok) - .unwrap_or_else(|| TagEvent { + .get(Some(room_id), user_id, RoomAccountDataEventType::Tag) + .await + .and_then(|event| serde_json::from_str(event.get()).map_err(Into::into)) + .unwrap_or_else(|_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -158,12 +157,15 @@ fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result< .tags .insert(tag.to_owned().into(), TagInfo::new()); - self.services.account_data.update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(event)?, - )?; + self.services + .account_data + .update( + Some(room_id), + user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(event)?, + ) + .await?; Ok(()) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 3274249e6..12eacc8fa 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -12,6 +12,7 @@ use std::{ use async_trait::async_trait; use conduit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server}; pub use create::create_admin_room; +use futures::{FutureExt, TryFutureExt}; use loole::{Receiver, Sender}; use ruma::{ events::{ @@ -142,17 +143,18 @@ impl Service { /// admin room as the admin user. pub async fn send_text(&self, body: &str) { self.send_message(RoomMessageEventContent::text_markdown(body)) - .await; + .await + .ok(); } /// Sends a message to the admin room as the admin user (see send_text() for /// convenience). - pub async fn send_message(&self, message_content: RoomMessageEventContent) { - if let Ok(Some(room_id)) = self.get_admin_room() { - let user_id = &self.services.globals.server_user; - self.respond_to_room(message_content, &room_id, user_id) - .await; - } + pub async fn send_message(&self, message_content: RoomMessageEventContent) -> Result<()> { + let user_id = &self.services.globals.server_user; + let room_id = self.get_admin_room().await?; + self.respond_to_room(message_content, &room_id, user_id) + .boxed() + .await } /// Posts a command to the command processor queue and returns. Processing @@ -193,8 +195,12 @@ impl Service { async fn handle_command(&self, command: CommandInput) { match self.process_command(command).await { - Ok(Some(output)) | Err(output) => self.handle_response(output).await, Ok(None) => debug!("Command successful with no response"), + Ok(Some(output)) | Err(output) => self + .handle_response(output) + .boxed() + .await + .unwrap_or_else(default_log), } } @@ -218,71 +224,67 @@ impl Service { } /// Checks whether a given user is an admin of this server - pub async fn user_is_admin(&self, user_id: &UserId) -> Result { - if let Ok(Some(admin_room)) = self.get_admin_room() { - self.services.state_cache.is_joined(user_id, &admin_room) - } else { - Ok(false) - } + pub async fn user_is_admin(&self, user_id: &UserId) -> bool { + let Ok(admin_room) = self.get_admin_room().await else { + return false; + }; + + self.services + .state_cache + .is_joined(user_id, &admin_room) + .await } /// Gets the room ID of the admin room /// /// Errors are propagated from the database, and will have None if there is /// no admin room - pub fn get_admin_room(&self) -> Result> { - if let Some(room_id) = self + pub async fn get_admin_room(&self) -> Result { + let room_id = self .services .alias - .resolve_local_alias(&self.services.globals.admin_alias)? - { - if self - .services - .state_cache - .is_joined(&self.services.globals.server_user, &room_id)? - { - return Ok(Some(room_id)); - } - } + .resolve_local_alias(&self.services.globals.admin_alias) + .await?; - Ok(None) + self.services + .state_cache + .is_joined(&self.services.globals.server_user, &room_id) + .await + .then_some(room_id) + .ok_or_else(|| err!(Request(NotFound("Admin user not joined to admin room")))) } - async fn handle_response(&self, content: RoomMessageEventContent) { + async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> { let Some(Relation::Reply { in_reply_to, }) = content.relates_to.as_ref() else { - return; + return Ok(()); }; - let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else { + let Ok(pdu) = self.services.timeline.get_pdu(&in_reply_to.event_id).await else { error!( event_id = ?in_reply_to.event_id, "Missing admin command in_reply_to event" ); - return; + return Ok(()); }; - let response_sender = if self.is_admin_room(&pdu.room_id) { + let response_sender = if self.is_admin_room(&pdu.room_id).await { &self.services.globals.server_user } else { &pdu.sender }; self.respond_to_room(content, &pdu.room_id, response_sender) - .await; + .await } - async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId) { - assert!( - self.user_is_admin(user_id) - .await - .expect("checked user is admin"), - "sender is not admin" - ); + async fn respond_to_room( + &self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId, + ) -> Result<()> { + assert!(self.user_is_admin(user_id).await, "sender is not admin"); - let state_lock = self.services.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -292,6 +294,7 @@ impl Service { timestamp: None, }; + let state_lock = self.services.state.mutex.lock(room_id).await; if let Err(e) = self .services .timeline @@ -302,6 +305,8 @@ impl Service { .await .unwrap_or_else(default_log); } + + Ok(()) } async fn handle_response_error( @@ -355,12 +360,12 @@ impl Service { } // Prevent unescaped !admin from being used outside of the admin room - if is_public_prefix && !self.is_admin_room(&pdu.room_id) { + if is_public_prefix && !self.is_admin_room(&pdu.room_id).await { return false; } // Only senders who are admin can proceed - if !self.user_is_admin(&pdu.sender).await.unwrap_or(false) { + if !self.user_is_admin(&pdu.sender).await { return false; } @@ -368,7 +373,7 @@ impl Service { // the administrator can execute commands as conduit let emergency_password_set = self.services.globals.emergency_password().is_some(); let from_server = pdu.sender == *server_user && !emergency_password_set; - if from_server && self.is_admin_room(&pdu.room_id) { + if from_server && self.is_admin_room(&pdu.room_id).await { return false; } @@ -377,12 +382,11 @@ impl Service { } #[must_use] - pub fn is_admin_room(&self, room_id: &RoomId) -> bool { - if let Ok(Some(admin_room_id)) = self.get_admin_room() { - admin_room_id == room_id - } else { - false - } + pub async fn is_admin_room(&self, room_id_: &RoomId) -> bool { + self.get_admin_room() + .map_ok(|room_id| room_id == room_id_) + .await + .unwrap_or(false) } /// Sets the self-reference to crate::Services which will provide context to diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 40e641a1e..8fb7d9582 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,7 +1,8 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; +use conduit::{err, utils::stream::TryIgnore, Result}; use database::{Database, Map}; +use futures::Stream; use ruma::api::appservice::Registration; pub struct Data { @@ -19,7 +20,7 @@ impl Data { pub(super) fn register_appservice(&self, yaml: &Registration) -> Result { let id = yaml.id.as_str(); self.id_appserviceregistrations - .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes()); Ok(id.to_owned()) } @@ -31,24 +32,19 @@ impl Data { /// * `service_name` - the name you send to register the service previously pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations - .remove(service_name.as_bytes())?; + .remove(service_name.as_bytes()); Ok(()) } - pub fn get_registration(&self, id: &str) -> Result> { + pub async fn get_registration(&self, id: &str) -> Result { self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes) - .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) - }) - .transpose() + .get(id) + .await + .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) + .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } - pub(super) fn iter_ids<'a>(&'a self) -> Result> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { - utils::string_from_bytes(&id) - .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) - }))) + pub(super) fn iter_ids(&self) -> impl Stream + Send + '_ { + self.id_appserviceregistrations.keys().ignore_err() } } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index c0752d565..7e2dc7387 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -2,9 +2,10 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; +use async_trait::async_trait; use conduit::{err, Result}; use data::Data; -use futures_util::Future; +use futures::{Future, StreamExt, TryStreamExt}; use regex::RegexSet; use ruma::{ api::appservice::{Namespace, Registration}, @@ -126,13 +127,22 @@ struct Services { sending: Dep, } +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { - let mut registration_info = BTreeMap::new(); - let db = Data::new(args.db); + Ok(Arc::new(Self { + db: Data::new(args.db), + services: Services { + sending: args.depend::("sending"), + }, + registration_info: RwLock::new(BTreeMap::new()), + })) + } + + async fn worker(self: Arc) -> Result<()> { // Inserting registrations into cache - for appservice in iter_ids(&db)? { - registration_info.insert( + for appservice in iter_ids(&self.db).await? { + self.registration_info.write().await.insert( appservice.0, appservice .1 @@ -141,13 +151,7 @@ impl crate::Service for Service { ); } - Ok(Arc::new(Self { - db, - services: Services { - sending: args.depend::("sending"), - }, - registration_info: RwLock::new(registration_info), - })) + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -155,7 +159,7 @@ impl crate::Service for Service { impl Service { #[inline] - pub fn all(&self) -> Result> { iter_ids(&self.db) } + pub async fn all(&self) -> Result> { iter_ids(&self.db).await } /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { @@ -188,7 +192,8 @@ impl Service { // sending to the URL self.services .sending - .cleanup_events(service_name.to_owned())?; + .cleanup_events(service_name.to_owned()) + .await; Ok(()) } @@ -251,15 +256,9 @@ impl Service { } } -fn iter_ids(db: &Data) -> Result> { - db.iter_ids()? - .filter_map(Result::ok) - .map(move |id| { - Ok(( - id.clone(), - db.get_registration(&id)? - .expect("iter_ids only returns appservices that exist"), - )) - }) - .collect() +async fn iter_ids(db: &Data) -> Result> { + db.iter_ids() + .then(|id| async move { Ok((id.clone(), db.get_registration(&id).await?)) }) + .try_collect() + .await } diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs index 1bb0843d4..c99a0891e 100644 --- a/src/service/emergency/mod.rs +++ b/src/service/emergency/mod.rs @@ -32,7 +32,12 @@ impl crate::Service for Service { } async fn worker(self: Arc) -> Result<()> { + if self.services.globals.is_read_only() { + return Ok(()); + } + self.set_emergency_access() + .await .inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?; Ok(()) @@ -44,7 +49,7 @@ impl crate::Service for Service { impl Service { /// Sets the emergency password and push rules for the @conduit account in /// case emergency password is set - fn set_emergency_access(&self) -> Result { + async fn set_emergency_access(&self) -> Result { let conduit_user = &self.services.globals.server_user; self.services @@ -56,17 +61,20 @@ impl Service { None => (Ruleset::new(), false), }; - self.services.account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; + self.services + .account_data + .update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + ) + .await?; if pwd_set { warn!( @@ -75,7 +83,7 @@ impl Service { ); } else { // logs out any users still in the server service account and removes sessions - self.services.users.deactivate_account(conduit_user)?; + self.services.users.deactivate_account(conduit_user).await?; } Ok(pwd_set) diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5b5d9f09d..57a295d99 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -3,9 +3,9 @@ use std::{ sync::{Arc, RwLock}, }; -use conduit::{trace, utils, Error, Result, Server}; -use database::{Database, Map}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{trace, utils, utils::rand, Error, Result, Server}; +use database::{Database, Deserialized, Map}; +use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, signatures::Ed25519KeyPair, @@ -83,7 +83,7 @@ impl Data { .checked_add(1) .expect("counter must not overflow u64"); - self.global.insert(COUNTER, &counter.to_be_bytes())?; + self.global.insert(COUNTER, &counter.to_be_bytes()); Ok(*counter) } @@ -102,7 +102,7 @@ impl Data { fn stored_count(global: &Arc) -> Result { global - .get(COUNTER)? + .get_blocking(COUNTER) .as_deref() .map_or(Ok(0_u64), utils::u64_from_bytes) } @@ -133,36 +133,18 @@ impl Data { futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - let short_roomid = self - .services - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + + pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else { + continue; + }; let roomid_bytes = room_id.as_bytes().to_vec(); let mut roomid_prefix = roomid_bytes.clone(); roomid_prefix.push(0xFF); - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push(Box::pin(async move { - let _result = self.services.typing.wait_for_update(&room_id).await; - })); - - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - // Key changes futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); @@ -174,6 +156,19 @@ impl Data { self.roomusertype_roomuserdataid .watch_prefix(&roomuser_prefix), ); + + // PDUs + let short_roomid = short_roomid.to_be_bytes().to_vec(); + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + let typing_room_id = room_id.to_owned(); + let typing_wait_for_update = async move { + self.services.typing.wait_for_update(&typing_room_id).await; + }; + + futures.push(typing_wait_for_update.boxed()); + futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); } let mut globaluserdata_prefix = vec![0xFF]; @@ -190,12 +185,14 @@ impl Data { // One time keys futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - futures.push(Box::pin(async move { + // Server shutdown + let server_shutdown = async move { while self.services.server.running() { - let _result = self.services.server.signal.subscribe().recv().await; + self.services.server.signal.subscribe().recv().await.ok(); } - })); + }; + futures.push(server_shutdown.boxed()); if !self.services.server.running() { return Ok(()); } @@ -209,17 +206,23 @@ impl Data { } pub fn load_keypair(&self) -> Result { - let keypair_bytes = self.global.get(b"keypair")?.map_or_else( - || { - let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair)?; - Ok::<_, Error>(keypair) - }, - |val| Ok(val.to_vec()), - )?; + let generate = |_| { + let keypair = Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"); - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); + let mut value = rand::string(8).as_bytes().to_vec(); + value.push(0xFF); + value.extend_from_slice(&keypair); + + self.global.insert(b"keypair", &value); + value + }; + + let keypair_bytes: Vec = self + .global + .get_blocking(b"keypair") + .map_or_else(generate, Into::into); + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); utils::string_from_bytes( // 1. version parts @@ -241,7 +244,10 @@ impl Data { } #[inline] - pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + pub fn remove_keypair(&self) -> Result<()> { + self.global.remove(b"keypair"); + Ok(()) + } /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored /// in room version > 4 @@ -250,15 +256,15 @@ impl Data { /// /// This doesn't actually check that the keys provided are newer than the /// old set. - pub fn add_signing_key( + pub async fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result> { + ) -> BTreeMap { // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + let signingkeys = self.server_signingkeys.get(origin).await; let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { + .and_then(|keys| serde_json::from_slice(&keys).map_err(Into::into)) + .unwrap_or_else(|_| { // Just insert "now", it doesn't matter ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) }); @@ -275,7 +281,7 @@ impl Data { self.server_signingkeys.insert( origin.as_bytes(), &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; + ); let mut tree = keys.verify_keys; tree.extend( @@ -284,45 +290,41 @@ impl Data { .map(|old| (old.0, VerifyKey::new(old.1.key))), ); - Ok(tree) + tree } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - pub fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .signing_keys_for(origin)? - .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { + pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { + self.signing_keys_for(origin).await.map_or_else( + |_| Ok(BTreeMap::new()), + |keys: ServerSigningKeys| { let mut tree = keys.verify_keys; tree.extend( keys.old_verify_keys .into_iter() .map(|old| (old.0, VerifyKey::new(old.1.key))), ); - tree - }); - - Ok(signingkeys) + Ok(tree) + }, + ) } - pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()); - - Ok(signingkeys) + pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.server_signingkeys.get(origin).await.deserialized() } - pub fn database_version(&self) -> Result { - self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) - }) + pub async fn database_version(&self) -> u64 { + self.global + .get(b"version") + .await + .deserialized() + .unwrap_or(0) } #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes())?; + self.global.insert(b"version", &new_version.to_be_bytes()); Ok(()) } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 66917520b..fc6e477b3 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -1,17 +1,15 @@ -use std::{ - collections::{HashMap, HashSet}, - fs::{self}, - io::Write, - mem::size_of, - sync::Arc, +use conduit::{ + debug_info, debug_warn, error, info, + result::NotFound, + utils::{stream::TryIgnore, IterStream, ReadyExt}, + warn, Err, Error, Result, }; - -use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Error, Result}; +use futures::{FutureExt, StreamExt}; use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, push::Ruleset, - EventId, OwnedRoomId, RoomId, UserId, + OwnedUserId, UserId, }; use crate::{media, Services}; @@ -33,12 +31,14 @@ pub(crate) const DATABASE_VERSION: u64 = 13; pub(crate) const CONDUIT_DATABASE_VERSION: u64 = 16; pub(crate) async fn migrations(services: &Services) -> Result<()> { + let users_count = services.users.count().await; + // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. - if services.users.count()? > 0 { + if users_count > 0 { let conduit_user = &services.globals.server_user; - if !services.users.exists(conduit_user)? { + if !services.users.exists(conduit_user).await { error!("The {} server user does not exist, and the database is not new.", conduit_user); return Err(Error::bad_database( "Cannot reuse an existing database after changing the server name, please delete the old one first.", @@ -46,7 +46,7 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> { } } - if services.users.count()? > 0 { + if users_count > 0 { migrate(services).await } else { fresh(services).await @@ -62,9 +62,9 @@ async fn fresh(services: &Services) -> Result<()> { .db .bump_database_version(DATABASE_VERSION)?; - db["global"].insert(b"feat_sha256_media", &[])?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + db["global"].insert(b"feat_sha256_media", &[]); + db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); // Create the admin room and server user on first run crate::admin::create_admin_room(services).await?; @@ -82,566 +82,132 @@ async fn migrate(services: &Services) -> Result<()> { let db = &services.db; let config = &services.server.config; - if services.globals.db.database_version()? < 1 { - db_lt_1(services).await?; - } - - if services.globals.db.database_version()? < 2 { - db_lt_2(services).await?; + if services.globals.db.database_version().await < 11 { + return Err!(Database( + "Database schema version {} is no longer supported", + services.globals.db.database_version().await + )); } - if services.globals.db.database_version()? < 3 { - db_lt_3(services).await?; - } - - if services.globals.db.database_version()? < 4 { - db_lt_4(services).await?; - } - - if services.globals.db.database_version()? < 5 { - db_lt_5(services).await?; - } - - if services.globals.db.database_version()? < 6 { - db_lt_6(services).await?; - } - - if services.globals.db.database_version()? < 7 { - db_lt_7(services).await?; - } - - if services.globals.db.database_version()? < 8 { - db_lt_8(services).await?; - } - - if services.globals.db.database_version()? < 9 { - db_lt_9(services).await?; - } - - if services.globals.db.database_version()? < 10 { - db_lt_10(services).await?; - } - - if services.globals.db.database_version()? < 11 { - db_lt_11(services).await?; - } - - if services.globals.db.database_version()? < 12 { + if services.globals.db.database_version().await < 12 { db_lt_12(services).await?; } // This migration can be reused as-is anytime the server-default rules are // updated. - if services.globals.db.database_version()? < 13 { + if services.globals.db.database_version().await < 13 { db_lt_13(services).await?; } - if db["global"].get(b"feat_sha256_media")?.is_none() { + if db["global"].get(b"feat_sha256_media").await.is_not_found() { media::migrations::migrate_sha256_media(services).await?; } else if config.media_startup_check { media::migrations::checkup_sha256_media(services).await?; } if db["global"] - .get(b"fix_bad_double_separator_in_state_cache")? - .is_none() + .get(b"fix_bad_double_separator_in_state_cache") + .await + .is_not_found() { fix_bad_double_separator_in_state_cache(services).await?; } if db["global"] - .get(b"retroactively_fix_bad_data_from_roomuserid_joined")? - .is_none() + .get(b"retroactively_fix_bad_data_from_roomuserid_joined") + .await + .is_not_found() { retroactively_fix_bad_data_from_roomuserid_joined(services).await?; } - let version_match = services.globals.db.database_version().unwrap() == DATABASE_VERSION - || services.globals.db.database_version().unwrap() == CONDUIT_DATABASE_VERSION; + let version_match = services.globals.db.database_version().await == DATABASE_VERSION + || services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION; assert!( version_match, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services.globals.db.database_version().unwrap(), + services.globals.db.database_version().await, DATABASE_VERSION, ); { let patterns = services.globals.forbidden_usernames(); if !patterns.is_empty() { - for user_id in services + services .users - .iter() - .filter_map(Result::ok) - .filter(|user| !services.users.is_deactivated(user).unwrap_or(true)) - .filter(|user| user.server_name() == config.server_name) - { - let matches = patterns.matches(user_id.localpart()); - if matches.matched_any() { - warn!( - "User {} matches the following forbidden username patterns: {}", - user_id.to_string(), - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - - { - let patterns = services.globals.forbidden_alias_names(); - if !patterns.is_empty() { - for address in services.rooms.metadata.iter_ids() { - let room_id = address?; - let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id); - for room_alias_result in room_aliases { - let room_alias = room_alias_result?; - let matches = patterns.matches(room_alias.alias()); + .stream() + .filter(|user_id| services.users.is_active_local(user_id)) + .ready_for_each(|user_id| { + let matches = patterns.matches(user_id.localpart()); if matches.matched_any() { warn!( - "Room with alias {} ({}) matches the following forbidden room name patterns: {}", - room_alias, - &room_id, + "User {} matches the following forbidden username patterns: {}", + user_id.to_string(), matches .into_iter() .map(|x| &patterns.patterns()[x]) .join(", ") ); } - } - } - } - } - - info!( - "Loaded {} database with schema version {DATABASE_VERSION}", - config.database_backend, - ); - - Ok(()) -} - -async fn db_lt_1(services: &Services) -> Result<()> { - let db = &services.db; - - let roomserverids = &db["roomserverids"]; - let serverroomids = &db["serverroomids"]; - for (roomserverid, _) in roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xFF); - let room_id = parts.next().expect("split always returns one element"); - let Some(servername) = parts.next() else { - error!("Migration: Invalid roomserverid in db."); - continue; - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xFF); - serverroomid.extend_from_slice(room_id); - - serverroomids.insert(&serverroomid, &[])?; - } - - services.globals.db.bump_database_version(1)?; - info!("Migration: 0 -> 1 finished"); - Ok(()) -} - -async fn db_lt_2(services: &Services) -> Result<()> { - let db = &services.db; - - // We accidentally inserted hashed versions of "" into the db instead of just "" - let userid_password = &db["roomserverids"]; - for (userid, password) in userid_password.iter() { - let empty_pass = utils::hash::password("").expect("our own password to be properly hashed"); - let password = std::str::from_utf8(&password).expect("password is valid utf-8"); - let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok(); - if empty_hashed_password { - userid_password.insert(&userid, b"")?; + }) + .await; } } - services.globals.db.bump_database_version(2)?; - info!("Migration: 1 -> 2 finished"); - Ok(()) -} - -async fn db_lt_3(services: &Services) -> Result<()> { - let db = &services.db; - - // Move media to filesystem - let mediaid_file = &db["mediaid_file"]; - for (key, content) in mediaid_file.iter() { - if content.is_empty() { - continue; - } - - #[allow(deprecated)] - let path = services.media.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - mediaid_file.insert(&key, &[])?; - } - - services.globals.db.bump_database_version(3)?; - info!("Migration: 2 -> 3 finished"); - Ok(()) -} - -async fn db_lt_4(services: &Services) -> Result<()> { - let config = &services.server.config; - - // Add federated users to services as deactivated - for our_user in services.users.iter() { - let our_user = our_user?; - if services.users.is_deactivated(&our_user)? { - continue; - } - for room in services.rooms.state_cache.rooms_joined(&our_user) { - for user in services.rooms.state_cache.room_members(&room?) { - let user = user?; - if user.server_name() != config.server_name { - info!(?user, "Migration: creating user"); - services.users.create(&user, None)?; - } - } - } - } - - services.globals.db.bump_database_version(4)?; - info!("Migration: 3 -> 4 finished"); - Ok(()) -} - -async fn db_lt_5(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade user data store - let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; - let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; - for (roomuserdataid, _) in roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id); - key.push(0xFF); - key.extend_from_slice(event_type); - - roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; - } - - services.globals.db.bump_database_version(5)?; - info!("Migration: 4 -> 5 finished"); - Ok(()) -} - -async fn db_lt_6(services: &Services) -> Result<()> { - let db = &services.db; - - // Set room member count - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - for (roomid, _) in roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services.rooms.state_cache.update_joined_count(room_id)?; - } - - services.globals.db.bump_database_version(6)?; - info!("Migration: 5 -> 6 finished"); - Ok(()) -} - -async fn db_lt_7(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade state store - let mut last_roomstates: HashMap = HashMap::new(); - let mut current_sstatehash: Option = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - - let handle_state = |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { + { + let patterns = services.globals.forbidden_alias_names(); + if !patterns.is_empty() { + for room_id in services + .rooms + .metadata + .iter_ids() + .map(ToOwned::to_owned) + .collect::>() + .await + { services .rooms - .state_compressor - .load_shortstatehash_info(last_roomsstatehash) - }, - )?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .copied() - .collect::>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - services.rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, // every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = services.rooms.load_shortstatehash_info(¤t_sstatehash)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>() - ); - */ - - Ok::<_, Error>(()) - }; - - let stateid_shorteventid = &db["stateid_shorteventid"]; - let shorteventid_eventid = &db["shorteventid_eventid"]; - for (k, seventid) in stateid_shorteventid.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]).expect("number of bytes is correct"); - let sstatekey = k[size_of::()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); + .alias + .local_aliases_for_room(&room_id) + .ready_for_each(|room_alias| { + let matches = patterns.matches(room_alias.alias()); + if matches.matched_any() { + warn!( + "Room with alias {} ({}) matches the following forbidden room name patterns: {}", + room_alias, + &room_id, + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") + ); + } + }) + .await; } } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - services.globals.db.bump_database_version(7)?; - info!("Migration: 6 -> 7 finished"); - Ok(()) -} - -async fn db_lt_8(services: &Services) -> Result<()> { - let db = &services.db; - - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - let pduid_pdu = &db["pduid_pdu"]; - let eventid_pduid = &db["eventid_pduid"]; - - // Generate short room ids for all rooms - for (room_id, _) in roomid_shortstatehash.iter() { - let shortroomid = services.globals.next_count()?.to_be_bytes(); - roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let batch = pduid_pdu - .iter() - .filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(count); - - Some(database::OwnedKeyVal(new_key, v)) - }) - .collect::>(); - - pduid_pdu.insert_batch(batch.iter().map(database::KeyVal::from))?; - - let batch2 = eventid_pduid - .iter() - .filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id.to_vec(); - new_value.extend_from_slice(count); - - Some(database::OwnedKeyVal(k, new_value)) - }) - .collect::>(); - - eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - - services.globals.db.bump_database_version(8)?; - info!("Migration: 7 -> 8 finished"); - Ok(()) -} - -async fn db_lt_9(services: &Services) -> Result<()> { - let db = &services.db; - - let tokenids = &db["tokenids"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - - // Update tokenids db layout - let mut iter = tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(word); - new_key.push(0xFF); - new_key.extend_from_slice(pdu_id_count); - Some(database::OwnedKeyVal(new_key, Vec::::new())) - }) - .peekable(); - - while iter.peek().is_some() { - let batch = iter.by_ref().take(1000).collect::>(); - tokenids.insert_batch(batch.iter().map(database::KeyVal::from))?; - debug!("Inserted smaller batch"); } - info!("Deleting starts"); - - let batch2: Vec<_> = tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - tokenids.remove(&key)?; - } - - services.globals.db.bump_database_version(9)?; - info!("Migration: 8 -> 9 finished"); - Ok(()) -} - -async fn db_lt_10(services: &Services) -> Result<()> { - let db = &services.db; - - let statekey_shortstatekey = &db["statekey_shortstatekey"]; - let shortstatekey_statekey = &db["shortstatekey_statekey"]; - - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in statekey_shortstatekey.iter() { - shortstatekey_statekey.insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send them over federation - for user_id in services.users.iter().filter_map(Result::ok) { - services.users.mark_device_key_update(&user_id)?; - } - - services.globals.db.bump_database_version(10)?; - info!("Migration: 9 -> 10 finished"); - Ok(()) -} - -#[allow(unreachable_code)] -async fn db_lt_11(services: &Services) -> Result<()> { - error!("Dropping a column to clear data is not implemented yet."); - //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; - //userdevicesessionid_uiaarequest.clear()?; + info!( + "Loaded {} database with schema version {DATABASE_VERSION}", + config.database_backend, + ); - services.globals.db.bump_database_version(11)?; - info!("Migration: 10 -> 11 finished"); Ok(()) } async fn db_lt_12(services: &Services) -> Result<()> { let config = &services.server.config; - for username in services.users.list_local_users()? { - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { Ok(u) => u, Err(e) => { warn!("Invalid username {username}: {e}"); @@ -652,7 +218,7 @@ async fn db_lt_12(services: &Services) -> Result<()> { let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() + .await .expect("Username is invalid"); let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); @@ -694,12 +260,15 @@ async fn db_lt_12(services: &Services) -> Result<()> { } } - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; } services.globals.db.bump_database_version(12)?; @@ -710,8 +279,14 @@ async fn db_lt_12(services: &Services) -> Result<()> { async fn db_lt_13(services: &Services) -> Result<()> { let config = &services.server.config; - for username in services.users.list_local_users()? { - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { Ok(u) => u, Err(e) => { warn!("Invalid username {username}: {e}"); @@ -722,7 +297,7 @@ async fn db_lt_13(services: &Services) -> Result<()> { let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() + .await .expect("Username is invalid"); let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); @@ -733,12 +308,15 @@ async fn db_lt_13(services: &Services) -> Result<()> { .global .update_with_server_default(user_default_rules); - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; } services.globals.db.bump_database_version(13)?; @@ -754,32 +332,37 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result< let _cork = db.cork_and_sync(); let mut iter_count: usize = 0; - for (mut key, value) in roomuserid_joined.iter() { - iter_count = iter_count.saturating_add(1); - debug_info!(%iter_count); - let first_sep_index = key - .iter() - .position(|&i| i == 0xFF) - .expect("found 0xFF delim"); + roomuserid_joined + .raw_stream() + .ignore_err() + .ready_for_each(|(key, value)| { + let mut key = key.to_vec(); + iter_count = iter_count.saturating_add(1); + debug_info!(%iter_count); + let first_sep_index = key + .iter() + .position(|&i| i == 0xFF) + .expect("found 0xFF delim"); - if key - .iter() - .get(first_sep_index..=first_sep_index.saturating_add(1)) - .copied() - .collect_vec() - == vec![0xFF, 0xFF] - { - debug_warn!("Found bad key: {key:?}"); - roomuserid_joined.remove(&key)?; + if key + .iter() + .get(first_sep_index..=first_sep_index.saturating_add(1)) + .copied() + .collect_vec() + == vec![0xFF, 0xFF] + { + debug_warn!("Found bad key: {key:?}"); + roomuserid_joined.remove(&key); - key.remove(first_sep_index); - debug_warn!("Fixed key: {key:?}"); - roomuserid_joined.insert(&key, &value)?; - } - } + key.remove(first_sep_index); + debug_warn!("Fixed key: {key:?}"); + roomuserid_joined.insert(&key, value); + } + }) + .await; db.db.cleanup()?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; + db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]); info!("Finished fixing"); Ok(()) @@ -795,69 +378,72 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .rooms .metadata .iter_ids() - .filter_map(Result::ok) - .collect_vec(); + .map(ToOwned::to_owned) + .collect::>() + .await; - for room_id in room_ids.clone() { + for room_id in &room_ids { debug_info!("Fixing room {room_id}"); - let users_in_room = services + let users_in_room: Vec = services .rooms .state_cache - .room_members(&room_id) - .filter_map(Result::ok) - .collect_vec(); + .room_members(room_id) + .map(ToOwned::to_owned) + .collect() + .await; let joined_members = users_in_room .iter() + .stream() .filter(|user_id| { services .rooms .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| membership.membership == MembershipState::Join) + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect_vec(); + .collect::>() + .await; let non_joined_members = users_in_room .iter() + .stream() .filter(|user_id| { services .rooms .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| { - membership.membership == MembershipState::Leave || membership.membership == MembershipState::Ban - }) + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect_vec(); + .collect::>() + .await; - for user_id in joined_members { + for user_id in &joined_members { debug_info!("User is joined, marking as joined"); - services - .rooms - .state_cache - .mark_as_joined(user_id, &room_id)?; + services.rooms.state_cache.mark_as_joined(user_id, room_id); } - for user_id in non_joined_members { + for user_id in &non_joined_members { debug_info!("User is left or banned, marking as left"); - services.rooms.state_cache.mark_as_left(user_id, &room_id)?; + services.rooms.state_cache.mark_as_left(user_id, room_id); } } - for room_id in room_ids { + for room_id in &room_ids { debug_info!( "Updating joined count for room {room_id} to fix servers in room after correcting membership states" ); - services.rooms.state_cache.update_joined_count(&room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; } db.db.cleanup()?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]); info!("Finished fixing"); Ok(()) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 87f8f4925..f24e8a274 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -288,8 +288,8 @@ impl Service { /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - pub fn verify_keys_for(&self, origin: &ServerName) -> Result> { - let mut keys = self.db.verify_keys_for(origin)?; + pub async fn verify_keys_for(&self, origin: &ServerName) -> Result> { + let mut keys = self.db.verify_keys_for(origin).await?; if origin == self.server_name() { keys.insert( format!("ed25519:{}", self.keypair().version()) @@ -304,8 +304,8 @@ impl Service { Ok(keys) } - pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { - self.db.signing_keys_for(origin) + pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { + self.db.signing_keys_for(origin).await } pub fn well_known_client(&self) -> &Option { &self.config.well_known.client } @@ -329,4 +329,7 @@ impl Service { #[inline] pub fn server_is_ours(&self, server_name: &ServerName) -> bool { server_name == self.config.server_name } + + #[inline] + pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() } } diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs deleted file mode 100644 index 30ac593b1..000000000 --- a/src/service/key_backups/data.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::{collections::BTreeMap, sync::Arc}; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - backupid_algorithm: Arc, - backupid_etag: Arc, - backupkeyid_backup: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - backupid_algorithm: db["backupid_algorithm"].clone(), - backupid_etag: db["backupid_etag"].clone(), - backupkeyid_backup: db["backupkeyid_backup"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - let version = self.services.globals.next_count()?.to_string(); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.insert( - &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - )?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version) - } - - pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.remove(&key)?; - self.backupid_etag.remove(&key)?; - - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version.to_owned()) - } - - pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, _)| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) - .transpose() - } - - pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, value)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - - Ok(( - version, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, - )) - }) - .transpose() - } - - pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm - .get(&key)? - .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) - } - - pub(super) fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes())?; - - Ok(()) - } - - pub(super) fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - - Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) - } - - pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - Ok(utils::u64_from_bytes( - &self - .backupid_etag - .get(&key)? - .ok_or_else(|| Error::bad_database("Backup has no etag."))?, - ) - .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? - .to_string()) - } - - pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - - let mut rooms = BTreeMap::::new(); - - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let room_id = RoomId::parse( - utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((room_id, session_id, key_data)) - }) { - let (room_id, session_id, key_data) = result?; - rooms - .entry(room_id) - .or_insert_with(|| RoomKeyBackup { - sessions: BTreeMap::new(), - }) - .sessions - .insert(session_id, key_data); - } - - Ok(rooms) - } - - pub(super) fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - Ok(self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((session_id, key_data)) - }) - .filter_map(Result::ok) - .collect()) - } - - pub(super) fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .get(&key)? - .map(|value| { - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) - }) - .transpose() - } - - pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } -} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 65d3c065e..decf32f7f 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,93 +1,311 @@ -mod data; - use std::{collections::BTreeMap, sync::Arc}; -use conduit::Result; -use data::Data; +use conduit::{ + err, implement, utils, + utils::stream::{ReadyExt, TryIgnore}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::StreamExt; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +use crate::{globals, Dep}; + pub struct Service { db: Data, + services: Services, +} + +struct Data { + backupid_algorithm: Arc, + backupid_etag: Arc, + backupkeyid_backup: Arc, +} + +struct Services { + globals: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + backupid_algorithm: args.db["backupid_algorithm"].clone(), + backupid_etag: args.db["backupid_etag"].clone(), + backupkeyid_backup: args.db["backupkeyid_backup"].clone(), + }, + services: Services { + globals: args.depend::("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - self.db.create_backup(user_id, backup_metadata) - } +#[implement(Service)] +pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + let version = self.services.globals.next_count()?.to_string(); - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_backup(user_id, version) - } + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw, - ) -> Result { - self.db.update_backup(user_id, version, backup_metadata) - } + self.db.backupid_algorithm.insert( + &key, + &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + ); - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - self.db.get_latest_backup_version(user_id) - } + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); - pub fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - self.db.get_latest_backup(user_id) - } + Ok(version) +} - pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - self.db.get_backup(user_id, version) - } +#[implement(Service)] +pub async fn delete_backup(&self, user_id: &UserId, version: &str) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - self.db - .add_key(user_id, version, room_id, session_id, key_data) + self.db.backupid_algorithm.remove(&key); + self.db.backupid_etag.remove(&key); + + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn update_backup( + &self, user_id: &UserId, version: &str, backup_metadata: &Raw, +) -> Result { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { self.db.count_keys(user_id, version) } + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { self.db.get_etag(user_id, version) } + self.db + .backupid_algorithm + .insert(&key, backup_metadata.json().get().as_bytes()); + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); - pub fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - self.db.get_all(user_id, version) - } + Ok(version.to_owned()) +} - pub fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - self.db.get_room(user_id, version, room_id) - } +#[implement(Service)] +pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - pub fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - self.db.get_session(user_id, version, room_id, session_id) - } + self.db + .backupid_algorithm + .rev_raw_keys_from(&last_possible_key) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup versions found")))) + .and_then(|key| { + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + }) +} - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_all_keys(user_id, version) - } +#[implement(Service)] +pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - self.db.delete_room_keys(user_id, version, room_id) - } + self.db + .backupid_algorithm + .rev_raw_stream_from(&last_possible_key) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup found")))) + .and_then(|(key, val)| { + let version = utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + + let algorithm = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?; - pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { - self.db - .delete_room_key(user_id, version, room_id, session_id) + Ok((version, algorithm)) + }) +} + +#[implement(Service)] +pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result> { + let key = (user_id, version); + self.db.backupid_algorithm.qry(&key).await.deserialized() +} + +#[implement(Service)] +pub async fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, +) -> Result<()> { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.db + .backupid_etag + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()); + + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); + + self.db + .backupkeyid_backup + .insert(&key, key_data.json().get().as_bytes()); + + Ok(()) +} + +#[implement(Service)] +pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize { + let prefix = (user_id, version); + self.db + .backupkeyid_backup + .keys_raw_prefix(&prefix) + .count() + .await +} + +#[implement(Service)] +pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String { + let key = (user_id, version); + self.db + .backupid_etag + .qry(&key) + .await + .deserialized::() + .as_ref() + .map(ToString::to_string) + .expect("Backup has no etag.") +} + +#[implement(Service)] +pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap { + type KeyVal<'a> = ((Ignore, Ignore, &'a RoomId, &'a str), &'a [u8]); + + let mut rooms = BTreeMap::::new(); + let default = || RoomKeyBackup { + sessions: BTreeMap::new(), + }; + + let prefix = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|((_, _, room_id, session_id), value): KeyVal<'_>| { + let key_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + rooms + .entry(room_id.into()) + .or_insert_with(default) + .sessions + .insert(session_id.into(), key_data); + }) + .await; + + rooms +} + +#[implement(Service)] +pub async fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, +) -> BTreeMap> { + type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), &'a [u8]); + + let prefix = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .map(|((.., session_id), value): KeyVal<'_>| { + let session_id = session_id.to_owned(); + let key_backup_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON"); + (session_id, key_backup_data) + }) + .collect() + .await +} + +#[implement(Service)] +pub async fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, +) -> Result> { + let key = (user_id, version, room_id, session_id); + + self.db.backupkeyid_backup.qry(&key).await.deserialized() +} + +#[implement(Service)] +pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) { + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) { + let key = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) { + let key = (user_id, version, room_id, session_id); + self.db + .backupkeyid_backup + .keys_raw_prefix(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; } diff --git a/src/service/manager.rs b/src/service/manager.rs index 42260bb30..21e0ed7c2 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -1,7 +1,7 @@ use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration}; use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server}; -use futures_util::FutureExt; +use futures::FutureExt; use tokio::{ sync::{Mutex, MutexGuard}, task::{JoinHandle, JoinSet}, diff --git a/src/service/media/data.rs b/src/service/media/data.rs index e5d6d20b1..248e9e1d2 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use conduit::{ debug, debug_info, trace, - utils::{str_from_bytes, string_from_bytes}, + utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt}, Err, Error, Result, }; use database::{Database, Map}; +use futures::StreamExt; use ruma::{api::client::error::ErrorKind, http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -59,7 +60,7 @@ impl Data { .unwrap_or_default(), ); - self.mediaid_file.insert(&key, &[])?; + self.mediaid_file.insert(&key, &[]); if let Some(user) = user { let mut key: Vec = Vec::new(); @@ -68,13 +69,13 @@ impl Data { key.extend_from_slice(b"/"); key.extend_from_slice(mxc.media_id.as_bytes()); let user = user.as_bytes().to_vec(); - self.mediaid_user.insert(&key, &user)?; + self.mediaid_user.insert(&key, &user); } Ok(key) } - pub(super) fn delete_file_mxc(&self, mxc: &Mxc<'_>) -> Result<()> { + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); let mut prefix: Vec = Vec::new(); @@ -85,25 +86,31 @@ impl Data { prefix.push(0xFF); trace!("MXC db prefix: {prefix:?}"); - for (key, _) in self.mediaid_file.scan_prefix(prefix.clone()) { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(&key)?; - } - - for (key, value) in self.mediaid_user.scan_prefix(prefix.clone()) { - if key.starts_with(&prefix) { - let user = str_from_bytes(&value).unwrap_or_default(); - - debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}"); - self.mediaid_user.remove(&key)?; - } - } + self.mediaid_file + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| { + debug!("Deleting key: {:?}", key); + self.mediaid_file.remove(key); + }) + .await; - Ok(()) + self.mediaid_user + .raw_stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|(key, val)| { + if key.starts_with(&prefix) { + let user = str_from_bytes(val).unwrap_or_default(); + debug_info!("Deleting key {key:?} which was uploaded by user {user}"); + + self.mediaid_user.remove(key); + } + }) + .await; } /// Searches for all files with the given MXC - pub(super) fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { + pub(super) async fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result>> { debug!("MXC URI: {mxc}"); let mut prefix: Vec = Vec::new(); @@ -115,9 +122,10 @@ impl Data { let keys: Vec> = self .mediaid_file - .scan_prefix(prefix) - .map(|(key, _)| key) - .collect(); + .keys_prefix_raw(&prefix) + .ignore_err() + .collect() + .await; if keys.is_empty() { return Err!(Database("Failed to find any keys in database for `{mxc}`",)); @@ -128,7 +136,7 @@ impl Data { Ok(keys) } - pub(super) fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { + pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result { let mut prefix: Vec = Vec::new(); prefix.extend_from_slice(b"mxc://"); prefix.extend_from_slice(mxc.server_name.as_bytes()); @@ -139,10 +147,13 @@ impl Data { prefix.extend_from_slice(&dim.height.to_be_bytes()); prefix.push(0xFF); - let (key, _) = self + let key = self .mediaid_file - .scan_prefix(prefix) + .raw_keys_prefix(&prefix) + .ignore_err() + .map(ToOwned::to_owned) .next() + .await .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; let mut parts = key.rsplit(|&b| b == 0xFF); @@ -177,28 +188,31 @@ impl Data { } /// Gets all the MXCs associated with a user - pub(super) fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec { - let user_id = user_id.as_bytes().to_vec(); - + pub(super) async fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec { self.mediaid_user - .iter() - .filter_map(|(key, user)| { - if *user == user_id { - let mxc_s = string_from_bytes(&key).ok()?; - Some(OwnedMxcUri::from(mxc_s)) - } else { - None - } - }) + .stream() + .ignore_err() + .ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into())) .collect() + .await } /// Gets all the media keys in our database (this includes all the metadata /// associated with it such as width, height, content-type, etc) - pub(crate) fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } + pub(crate) async fn get_all_media_keys(&self) -> Vec> { + self.mediaid_file + .raw_keys() + .ignore_err() + .map(<[u8]>::to_vec) + .collect() + .await + } #[inline] - pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } + pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { + self.url_previews.remove(url.as_bytes()); + Ok(()) + } pub(super) fn set_url_preview( &self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration, @@ -233,11 +247,13 @@ impl Data { value.push(0xFF); value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); - self.url_previews.insert(url.as_bytes(), &value) + self.url_previews.insert(url.as_bytes(), &value); + + Ok(()) } - pub(super) fn get_url_preview(&self, url: &str) -> Option { - let values = self.url_previews.get(url.as_bytes()).ok()??; + pub(super) async fn get_url_preview(&self, url: &str) -> Result { + let values = self.url_previews.get(url).await?; let mut values = values.split(|&b| b == 0xFF); @@ -291,7 +307,7 @@ impl Data { x => x, }; - Some(UrlPreviewData { + Ok(UrlPreviewData { title, description, image, diff --git a/src/service/media/migrations.rs b/src/service/media/migrations.rs index 9968d25b7..2d1b39f9f 100644 --- a/src/service/media/migrations.rs +++ b/src/service/media/migrations.rs @@ -7,7 +7,11 @@ use std::{ time::Instant, }; -use conduit::{debug, debug_info, debug_warn, error, info, warn, Config, Result}; +use conduit::{ + debug, debug_info, debug_warn, error, info, + utils::{stream::TryIgnore, ReadyExt}, + warn, Config, Result, +}; use crate::{globals, Services}; @@ -23,12 +27,17 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Move old media files to new names let mut changes = Vec::<(PathBuf, PathBuf)>::new(); - for (key, _) in mediaid_file.iter() { - let old = services.media.get_media_file_b64(&key); - let new = services.media.get_media_file_sha256(&key); - debug!(?key, ?old, ?new, num = changes.len(), "change"); - changes.push((old, new)); - } + mediaid_file + .raw_keys() + .ignore_err() + .ready_for_each(|key| { + let old = services.media.get_media_file_b64(key); + let new = services.media.get_media_file_sha256(key); + debug!(?key, ?old, ?new, num = changes.len(), "change"); + changes.push((old, new)); + }) + .await; + // move the file to the new location for (old_path, path) in changes { if old_path.exists() { @@ -41,11 +50,11 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services.globals.db.database_version()? == 14 && globals::migrations::DATABASE_VERSION == 13 { + if services.globals.db.database_version().await == 14 && globals::migrations::DATABASE_VERSION == 13 { services.globals.db.bump_database_version(13)?; } - db["global"].insert(b"feat_sha256_media", &[])?; + db["global"].insert(b"feat_sha256_media", &[]); info!("Finished applying sha256_media"); Ok(()) } @@ -71,7 +80,7 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> { .filter_map(|ent| ent.map_or(None, |ent| Some(ent.path().into_os_string()))) .collect(); - for key in media.db.get_all_media_keys() { + for key in media.db.get_all_media_keys().await { let new_path = media.get_media_file_sha256(&key).into_os_string(); let old_path = media.get_media_file_b64(&key).into_os_string(); if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await { @@ -112,8 +121,8 @@ async fn handle_media_check( "Media is missing at all paths. Removing from database..." ); - mediaid_file.remove(key)?; - mediaid_user.remove(key)?; + mediaid_file.remove(key); + mediaid_user.remove(key); } if config.media_compat_file_link && !old_exists && new_exists { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index d3765a176..c0b15726f 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -97,7 +97,7 @@ impl Service { /// Deletes a file in the database and from the media directory via an MXC pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> { - if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc) { + if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc).await { for key in keys { trace!(?mxc, "MXC Key: {key:?}"); debug_info!(?mxc, "Deleting from filesystem"); @@ -107,7 +107,7 @@ impl Service { } debug_info!(?mxc, "Deleting from database"); - _ = self.db.delete_file_mxc(mxc); + self.db.delete_file_mxc(mxc).await; } Ok(()) @@ -120,7 +120,7 @@ impl Service { /// /// currently, this is only practical for local users pub async fn delete_from_user(&self, user: &UserId) -> Result { - let mxcs = self.db.get_all_user_mxcs(user); + let mxcs = self.db.get_all_user_mxcs(user).await; let mut deletion_count: usize = 0; for mxc in mxcs { @@ -150,7 +150,7 @@ impl Service { content_disposition, content_type, key, - }) = self.db.search_file_metadata(mxc, &Dim::default()) + }) = self.db.search_file_metadata(mxc, &Dim::default()).await { let mut content = Vec::new(); let path = self.get_media_file(&key); @@ -170,7 +170,7 @@ impl Service { /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result> { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut mxcs = Vec::with_capacity(all_keys.len()); @@ -209,7 +209,7 @@ impl Service { pub async fn delete_all_remote_media_at_after_time( &self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool, ) -> Result { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut remote_mxcs = Vec::with_capacity(all_keys.len()); for key in all_keys { @@ -343,9 +343,10 @@ impl Service { } #[inline] - pub fn get_metadata(&self, mxc: &Mxc<'_>) -> Option { + pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option { self.db .search_file_metadata(mxc, &Dim::default()) + .await .map(|metadata| FileMeta { content_disposition: metadata.content_disposition, content_type: metadata.content_type, diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 5704075e5..6b1473838 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -71,16 +71,16 @@ pub async fn download_image(&self, url: &str) -> Result { #[implement(Service)] pub async fn get_url_preview(&self, url: &str) -> Result { - if let Some(preview) = self.db.get_url_preview(url) { + if let Ok(preview) = self.db.get_url_preview(url).await { return Ok(preview); } // ensure that only one request is made per URL let _request_lock = self.url_preview_mutex.lock(url).await; - match self.db.get_url_preview(url) { - Some(preview) => Ok(preview), - None => self.request_url_preview(url).await, + match self.db.get_url_preview(url).await { + Ok(preview) => Ok(preview), + Err(_) => self.request_url_preview(url).await, } } diff --git a/src/service/media/thumbnail.rs b/src/service/media/thumbnail.rs index 630f7b3b1..04ec03039 100644 --- a/src/service/media/thumbnail.rs +++ b/src/service/media/thumbnail.rs @@ -54,9 +54,9 @@ impl super::Service { // 0, 0 because that's the original file let dim = dim.normalized(); - if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim) { + if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim).await { self.get_thumbnail_saved(metadata).await - } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()) { + } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()).await { self.get_thumbnail_generate(mxc, &dim, metadata).await } else { Ok(None) diff --git a/src/service/mod.rs b/src/service/mod.rs index f588a5420..cb8bfcd95 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod resolver; pub mod rooms; pub mod sending; pub mod server_keys; +pub mod sync; pub mod transaction_ids; pub mod uiaa; pub mod updates; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index ec036b3d6..9c9d0ae3f 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,7 +1,12 @@ use std::sync::Arc; -use conduit::{debug_warn, utils, Error, Result}; -use database::Map; +use conduit::{ + debug_warn, utils, + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Deserialized, Map}; +use futures::Stream; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use super::Presence; @@ -31,39 +36,35 @@ impl Data { } } - pub fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - - let key = presenceid_key(count, user_id); - self.presenceid_presence - .get(&key)? - .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok(( - count, - Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?, - )) - }) - .transpose() - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result<(u64, PresenceEvent)> { + let count = self + .userid_presenceid + .get(user_id) + .await + .deserialized::()?; + + let key = presenceid_key(count, user_id); + let bytes = self.presenceid_presence.get(&key).await?; + let event = Presence::from_json_bytes(&bytes)? + .to_presence_event(user_id, &self.services.users) + .await; + + Ok((count, event)) } - pub(super) fn set_presence( + pub(super) async fn set_presence( &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { - let last_presence = self.get_presence(user_id)?; + let last_presence = self.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some(ref presence) => presence.1.content.presence != *presence_state, + Err(_) => true, + Ok(ref presence) => presence.1.content.presence != *presence_state, }; let status_msg_changed = match last_presence { - None => true, - Some(ref last_presence) => { + Err(_) => true, + Ok(ref last_presence) => { let old_msg = last_presence .1 .content @@ -79,8 +80,8 @@ impl Data { let now = utils::millis_since_unix_epoch(); let last_last_active_ts = match last_presence { - None => 0, - Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), + Err(_) => 0, + Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), }; let last_active_ts = match last_active_ago { @@ -90,12 +91,7 @@ impl Data { // TODO: tighten for state flicker? if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts { - debug_warn!( - "presence spam {:?} last_active_ts:{:?} < {:?}", - user_id, - last_active_ts, - last_last_active_ts - ); + debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",); return Ok(()); } @@ -115,41 +111,42 @@ impl Data { let key = presenceid_key(count, user_id); self.presenceid_presence - .insert(&key, &presence.to_json_bytes()?)?; + .insert(&key, &presence.to_json_bytes()?); self.userid_presenceid - .insert(user_id.as_bytes(), &count.to_be_bytes())?; + .insert(user_id.as_bytes(), &count.to_be_bytes()); - if let Some((last_count, _)) = last_presence { + if let Ok((last_count, _)) = last_presence { let key = presenceid_key(last_count, user_id); - self.presenceid_presence.remove(&key)?; + self.presenceid_presence.remove(&key); } Ok(()) } - pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - let key = presenceid_key(count, user_id); - self.presenceid_presence.remove(&key)?; - self.userid_presenceid.remove(user_id.as_bytes())?; - } + pub(super) async fn remove_presence(&self, user_id: &UserId) { + let Ok(count) = self + .userid_presenceid + .get(user_id) + .await + .deserialized::() + else { + return; + }; - Ok(()) + let key = presenceid_key(count, user_id); + self.presenceid_presence.remove(&key); + self.userid_presenceid.remove(user_id.as_bytes()); } - pub fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { - Box::new( - self.presenceid_presence - .iter() - .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec)> { - let (count, user_id) = presenceid_parse(&key)?; - Ok((user_id.to_owned(), count, presence_bytes)) - }) - .filter(move |(_, count, _)| *count > since), - ) + pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { + self.presenceid_presence + .raw_stream() + .ignore_err() + .ready_filter_map(move |(key, presence_bytes)| { + let (count, user_id) = presenceid_parse(key).expect("invalid presenceid_parse"); + (count > since).then(|| (user_id.to_owned(), count, presence_bytes.to_vec())) + }) } } @@ -162,7 +159,7 @@ fn presenceid_key(count: u64, user_id: &UserId) -> Vec { fn presenceid_parse(key: &[u8]) -> Result<(u64, &UserId)> { let (count, user_id) = key.split_at(8); let user_id = user_id_from_bytes(user_id)?; - let count = utils::u64_from_bytes(count).unwrap(); + let count = utils::u64_from_u8(count); Ok((count, user_id)) } diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index a54a6d7c5..3b5c4caf4 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -4,8 +4,8 @@ mod presence; use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{checked, debug, error, Error, Result, Server}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{checked, debug, error, result::LogErr, Error, Result, Server}; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use tokio::{sync::Mutex, time::sleep}; @@ -58,7 +58,9 @@ impl crate::Service for Service { loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { - Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + Some(user_id) = presence_timers.next() => { + self.process_presence_timer(&user_id).await.log_err().ok(); + }, event = receiver.recv_async() => match event { Err(_e) => return Ok(()), Ok((user_id, timeout)) => { @@ -82,28 +84,27 @@ impl crate::Service for Service { impl Service { /// Returns the latest presence event for the given user. #[inline] - pub fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some((_, presence)) = self.db.get_presence(user_id)? { - Ok(Some(presence)) - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result { + self.db + .get_presence(user_id) + .map_ok(|(_, presence)| presence) + .await } /// Pings the presence of the given user in the given room, setting the /// specified state. - pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { + pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000; - let last_presence = self.db.get_presence(user_id)?; + let last_presence = self.db.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some((_, ref presence)) => presence.content.presence != *new_state, + Err(_) => true, + Ok((_, ref presence)) => presence.content.presence != *new_state, }; let last_last_active_ago = match last_presence { - None => 0_u64, - Some((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), + Err(_) => 0_u64, + Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), }; if !state_changed && last_last_active_ago < REFRESH_TIMEOUT { @@ -111,17 +112,18 @@ impl Service { } let status_msg = match last_presence { - Some((_, ref presence)) => presence.content.status_msg.clone(), - None => Some(String::new()), + Ok((_, ref presence)) => presence.content.status_msg.clone(), + Err(_) => Some(String::new()), }; let last_active_ago = UInt::new(0); let currently_active = *new_state == PresenceState::Online; self.set_presence(user_id, new_state, Some(currently_active), last_active_ago, status_msg) + .await } /// Adds a presence event which will be saved until a new event replaces it. - pub fn set_presence( + pub async fn set_presence( &self, user_id: &UserId, state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { @@ -131,7 +133,8 @@ impl Service { }; self.db - .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?; + .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg) + .await?; if self.timeout_remote_users || self.services.globals.user_is_local(user_id) { let timeout = match presence_state { @@ -154,28 +157,33 @@ impl Service { /// /// TODO: Why is this not used? #[allow(dead_code)] - pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } + pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await } /// Returns the most recent presence updates that happened after the event /// with id `since`. #[inline] - pub fn presence_since(&self, since: u64) -> Box)> + '_> { + pub fn presence_since(&self, since: u64) -> impl Stream)> + Send + '_ { self.db.presence_since(since) } - pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { + #[inline] + pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { let presence = Presence::from_json_bytes(bytes)?; - presence.to_presence_event(user_id, &self.services.users) + let event = presence + .to_presence_event(user_id, &self.services.users) + .await; + + Ok(event) } - fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + async fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; let mut status_msg = None; - let presence_event = self.get_presence(user_id)?; + let presence_event = self.get_presence(user_id).await; - if let Some(presence_event) = presence_event { + if let Ok(presence_event) = presence_event { presence_state = presence_event.content.presence; last_active_ago = presence_event.content.last_active_ago; status_msg = presence_event.content.status_msg; @@ -192,7 +200,8 @@ impl Service { ); if let Some(new_state) = new_state { - self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg) + .await?; } Ok(()) diff --git a/src/service/presence/presence.rs b/src/service/presence/presence.rs index 570008f29..0d5c226bf 100644 --- a/src/service/presence/presence.rs +++ b/src/service/presence/presence.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use conduit::{utils, Error, Result}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -42,7 +40,7 @@ impl Presence { } /// Creates a PresenceEvent from available data. - pub(super) fn to_presence_event(&self, user_id: &UserId, users: &Arc) -> Result { + pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -50,16 +48,16 @@ impl Presence { Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts))) }; - Ok(PresenceEvent { + PresenceEvent { sender: user_id.to_owned(), content: PresenceEventContent { presence: self.state.clone(), status_msg: self.status_msg.clone(), currently_active: Some(self.currently_active), last_active_ago, - displayname: users.displayname(user_id)?, - avatar_url: users.avatar_url(user_id)?, + displayname: users.displayname(user_id).await.ok(), + avatar_url: users.avatar_url(user_id).await.ok(), }, - }) + } } } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs deleted file mode 100644 index f97343341..000000000 --- a/src/service/pusher/data.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::push::{set_pusher, Pusher}, - UserId, -}; - -pub(super) struct Data { - senderkey_pusher: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - senderkey_pusher: db["senderkey_pusher"].clone(), - } - } - - pub(super) fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - match pusher { - set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.senderkey_pusher - .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value"))?; - Ok(()) - }, - set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.senderkey_pusher.remove(&key).map_err(Into::into) - }, - } - } - - pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xFF); - senderkey.extend_from_slice(pushkey.as_bytes()); - - self.senderkey_pusher - .get(&senderkey)? - .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .transpose() - } - - pub(super) fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .collect() - } - - pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xFF); - let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; - - Ok(push_key_string) - })) - } -} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index de87264c9..8d8b553fe 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,9 +1,13 @@ -mod data; - use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_error, err, trace, utils::string_from_bytes, warn, Err, PduEvent, Result}; +use conduit::{ + debug_error, err, trace, + utils::{stream::TryIgnore, string_from_bytes}, + Err, PduEvent, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ api::{ @@ -22,12 +26,11 @@ use ruma::{ uint, RoomId, UInt, UserId, }; -use self::data::Data; use crate::{client, globals, rooms, users, Dep}; pub struct Service { - services: Services, db: Data, + services: Services, } struct Services { @@ -38,9 +41,16 @@ struct Services { users: Dep, } +struct Data { + senderkey_pusher: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + db: Data { + senderkey_pusher: args.db["senderkey_pusher"].clone(), + }, services: Services { globals: args.depend::("globals"), client: args.depend::("client"), @@ -48,7 +58,6 @@ impl crate::Service for Service { state_cache: args.depend::("rooms::state_cache"), users: args.depend::("users"), }, - db: Data::new(args.db), })) } @@ -56,19 +65,52 @@ impl crate::Service for Service { } impl Service { - pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - self.db.set_pusher(sender, pusher) + pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) { + match pusher { + set_pusher::v3::PusherAction::Post(data) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); + self.db + .senderkey_pusher + .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value")); + }, + set_pusher::v3::PusherAction::Delete(ids) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(ids.pushkey.as_bytes()); + self.db.senderkey_pusher.remove(&key); + }, + } } - pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - self.db.get_pusher(sender, pushkey) + pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result { + let senderkey = (sender, pushkey); + self.db + .senderkey_pusher + .qry(&senderkey) + .await + .deserialized() } - pub fn get_pushers(&self, sender: &UserId) -> Result> { self.db.get_pushers(sender) } + pub async fn get_pushers(&self, sender: &UserId) -> Vec { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .stream_prefix(&prefix) + .ignore_err() + .map(|(_, val): (Ignore, &[u8])| serde_json::from_slice(val).expect("Invalid Pusher in db.")) + .collect() + .await + } - #[must_use] - pub fn get_pushkeys(&self, sender: &UserId) -> Box> + '_> { - self.db.get_pushkeys(sender) + pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, pushkey): (Ignore, &str)| pushkey) } #[tracing::instrument(skip(self, dest, request))] @@ -161,15 +203,18 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .and_then(|ev| { serde_json::from_str(ev.content.get()) - .map_err(|e| err!(Database("invalid m.room.power_levels event: {e:?}"))) + .map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}")))) }) - .transpose()? .unwrap_or_default(); - for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? { + for action in self + .get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id) + .await? + { let n = match action { Action::Notify => true, Action::SetTweak(tweak) => { @@ -197,7 +242,7 @@ impl Service { } #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] - pub fn get_actions<'a>( + pub async fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw, room_id: &RoomId, ) -> Result<&'a [Action]> { @@ -207,21 +252,27 @@ impl Service { notifications: power_levels.notifications.clone(), }; + let room_joined_count = self + .services + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(1) + .try_into() + .unwrap_or_else(|_| uint!(0)); + + let user_display_name = self + .services + .users + .displayname(user) + .await + .unwrap_or_else(|_| user.localpart().to_owned()); + let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), - member_count: UInt::try_from( - self.services - .state_cache - .room_joined_count(room_id)? - .unwrap_or(1), - ) - .unwrap_or_else(|_| uint!(0)), + member_count: room_joined_count, user_id: user.to_owned(), - user_display_name: self - .services - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), + user_display_name, power_levels: Some(power_levels), }; @@ -278,9 +329,14 @@ impl Service { notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - notifi.sender_display_name = self.services.users.displayname(&event.sender)?; + notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok(); - notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?; + notifi.room_name = self + .services + .state_accessor + .get_name(&event.room_id) + .await + .ok(); self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 07d9a0fae..ea4b1100f 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -193,7 +193,7 @@ impl super::Service { .send() .await; - trace!("response: {:?}", response); + trace!("response: {response:?}"); if let Err(e) = &response { debug!("error: {e:?}"); return Ok(None); @@ -206,7 +206,7 @@ impl super::Service { } let text = response.text().await?; - trace!("response text: {:?}", text); + trace!("response text: {text:?}"); if text.len() >= 12288 { debug_warn!("response contains junk"); return Ok(None); @@ -225,7 +225,7 @@ impl super::Service { return Ok(None); } - debug_info!("{:?} found at {:?}", dest, m_server); + debug_info!("{dest:?} found at {m_server:?}"); Ok(Some(m_server.to_owned())) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs deleted file mode 100644 index efd2b5b76..000000000 --- a/src/service/rooms/alias/data.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - alias_userid: Arc, - alias_roomid: Arc, - aliasid_alias: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - alias_userid: db["alias_userid"].clone(), - alias_roomid: db["alias_roomid"].clone(), - aliasid_alias: db["aliasid_alias"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { - // Comes first as we don't want a stuck alias - self.alias_userid - .insert(alias.alias().as_bytes(), user_id.as_bytes())?; - - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xFF); - aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; - - Ok(()) - } - - pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id.to_vec(); - prefix.push(0xFF); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - - self.alias_roomid.remove(alias.alias().as_bytes())?; - - self.alias_userid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist or is invalid.")); - } - - Ok(()) - } - - pub(super) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_userid - .get(alias.alias().as_bytes())? - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("User ID in alias_userid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a + Send> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - })) - } - - pub(super) fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - Box::new( - self.alias_roomid - .iter() - .map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - - Ok((room_id, room_alias_localpart)) - }), - ) - } -} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index f2e01ab54..1d44cd2d8 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,19 +1,23 @@ -mod data; mod remote; use std::sync::Arc; -use conduit::{err, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ api::client::error::ErrorKind, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, }, - OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId, }; -use self::data::Data; use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep}; pub struct Service { @@ -21,6 +25,12 @@ pub struct Service { services: Services, } +struct Data { + alias_userid: Arc, + alias_roomid: Arc, + aliasid_alias: Arc, +} + struct Services { admin: Dep, appservice: Dep, @@ -32,7 +42,11 @@ struct Services { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + alias_userid: args.db["alias_userid"].clone(), + alias_roomid: args.db["alias_roomid"].clone(), + aliasid_alias: args.db["aliasid_alias"].clone(), + }, services: Services { admin: args.depend::("admin"), appservice: args.depend::("appservice"), @@ -50,25 +64,52 @@ impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user { - Err(Error::BadRequest( + return Err(Error::BadRequest( ErrorKind::forbidden(), "Only the server user can set this alias", - )) - } else { - self.db.set_alias(alias, room_id, user_id) + )); } + + // Comes first as we don't want a stuck alias + self.db + .alias_userid + .insert(alias.alias().as_bytes(), user_id.as_bytes()); + + self.db + .alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes()); + + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xFF); + aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + self.db.aliasid_alias.insert(&aliasid, alias.as_bytes()); + + Ok(()) } #[tracing::instrument(skip(self))] pub async fn remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<()> { - if self.user_can_remove_alias(alias, user_id).await? { - self.db.remove_alias(alias) - } else { - Err(Error::BadRequest( - ErrorKind::forbidden(), - "User is not permitted to remove this alias.", - )) + if !self.user_can_remove_alias(alias, user_id).await? { + return Err!(Request(Forbidden("User is not permitted to remove this alias."))); } + + let alias = alias.alias(); + let Ok(room_id) = self.db.alias_roomid.get(&alias).await else { + return Err!(Request(NotFound("Alias does not exist or is invalid."))); + }; + + let prefix = (&room_id, Interfix); + self.db + .aliasid_alias + .keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key: &[u8]| self.db.aliasid_alias.remove(&key)) + .await; + + self.db.alias_roomid.remove(alias.as_bytes()); + self.db.alias_userid.remove(alias.as_bytes()); + + Ok(()) } pub async fn resolve(&self, room: &RoomOrAliasId) -> Result { @@ -97,9 +138,9 @@ impl Service { return self.remote_resolve(room_alias, servers).await; } - let room_id: Option = match self.resolve_local_alias(room_alias)? { - Some(r) => Some(r), - None => self.resolve_appservice_alias(room_alias).await?, + let room_id: Option = match self.resolve_local_alias(room_alias).await { + Ok(r) => Some(r), + Err(_) => self.resolve_appservice_alias(room_alias).await?, }; room_id.map_or_else( @@ -109,46 +150,54 @@ impl Service { } #[tracing::instrument(skip(self), level = "debug")] - pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.db.resolve_local_alias(alias) + pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result { + self.db.alias_roomid.get(alias.alias()).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a + Send> { - self.db.local_aliases_for_room(room_id) + pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .aliasid_alias + .stream_prefix(&prefix) + .ignore_err() + .map(|((Ignore, Ignore), alias): ((Ignore, Ignore), &RoomAliasId)| alias) } #[tracing::instrument(skip(self), level = "debug")] - pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - self.db.all_local_aliases() + pub fn all_local_aliases<'a>(&'a self) -> impl Stream + Send + 'a { + self.db + .alias_roomid + .stream() + .ignore_err() + .map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart)) } async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result { - let Some(room_id) = self.resolve_local_alias(alias)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); - }; + let room_id = self + .resolve_local_alias(alias) + .await + .map_err(|_| err!(Request(NotFound("Alias not found."))))?; let server_user = &self.services.globals.server_user; // The creator of an alias can remove it if self - .db - .who_created_alias(alias)? - .is_some_and(|user| user == user_id) + .who_created_alias(alias).await + .is_ok_and(|user| user == user_id) // Server admins can remove any local alias - || self.services.admin.user_is_admin(user_id).await? + || self.services.admin.user_is_admin(user_id).await // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { Ok(true) // Checking whether the user is able to change canonical aliases of the // room - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? + } else if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) @@ -157,10 +206,11 @@ impl Service { }) // If there is no power levels event, only the room creator can change // canonical aliases - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + } else if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -168,6 +218,10 @@ impl Service { } } + async fn who_created_alias(&self, alias: &RoomAliasId) -> Result { + self.db.alias_userid.get(alias.alias()).await.deserialized() + } + async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result> { use ruma::api::appservice::query::query_room_alias; @@ -185,10 +239,11 @@ impl Service { .await, Ok(Some(_opt_result)) ) { - return Ok(Some( - self.resolve_local_alias(room_alias)? - .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, - )); + return self + .resolve_local_alias(room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room does not exist.")))) + .map(Some); } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 6e7c78359..5c9dbda83 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{utils, utils::math::usize_from_f64, Result}; +use conduit::{err, utils, utils::math::usize_from_f64, Err, Result}; use database::Map; use lru_cache::LruCache; @@ -24,57 +24,63 @@ impl Data { } } - pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); + if let Some(result) = self + .auth_chain_cache + .lock() + .expect("cache locked") + .get_mut(key) + { + return Ok(Arc::clone(result)); } // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - let chain = self - .shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(utils::u64_from_u8) - .collect::>() - }); + if key.len() != 1 { + return Err!(Request(NotFound("auth_chain not cached"))); + } - if let Some(chain) = chain { - // Cache in RAM - self.auth_chain_cache - .lock() - .expect("locked") - .insert(vec![key[0]], Arc::clone(&chain)); + // Check database + let chain = self + .shorteventid_authchain + .qry(&key[0]) + .await + .map_err(|_| err!(Request(NotFound("auth_chain not found"))))?; - return Ok(Some(chain)); - } - } + let chain = chain + .chunks_exact(size_of::()) + .map(utils::u64_from_u8) + .collect::>(); + + // Cache in RAM + self.auth_chain_cache + .lock() + .expect("cache locked") + .insert(vec![key[0]], Arc::clone(&chain)); - Ok(None) + Ok(chain) } - pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { + pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Only persist single events in db if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - )?; + let key = key[0].to_be_bytes(); + let val = auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(); + + self.shorteventid_authchain.insert(&key, &val); } // Cache in RAM self.auth_chain_cache .lock() - .expect("locked") + .expect("cache locked") .insert(key, auth_chain); - - Ok(()) } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 9a1e7e67a..f3861ca3f 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,7 +5,8 @@ use std::{ sync::Arc, }; -use conduit::{debug, error, trace, validated, warn, Err, Result}; +use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; +use futures::Stream; use ruma::{EventId, RoomId}; use self::data::Data; @@ -36,19 +37,30 @@ impl crate::Service for Service { } impl Service { - pub async fn event_ids_iter<'a>( - &'a self, room_id: &RoomId, starting_events_: Vec>, - ) -> Result> + 'a> { - let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); - for starting_event in &starting_events_ { - starting_events.push(starting_event); - } - - Ok(self - .get_auth_chain(room_id, &starting_events) + pub async fn event_ids_iter( + &self, room_id: &RoomId, starting_events: &[&EventId], + ) -> Result> + Send + '_> { + let stream = self + .get_event_ids(room_id, starting_events) .await? .into_iter() - .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok())) + .stream(); + + Ok(stream) + } + + pub async fn get_event_ids(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result>> { + let chain = self.get_auth_chain(room_id, starting_events).await?; + let event_ids = self + .services + .short + .multi_get_eventid_from_short(&chain) + .await + .into_iter() + .filter_map(Result::ok) + .collect(); + + Ok(event_ids) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -61,12 +73,13 @@ impl Service { for (i, &short) in self .services .short - .multi_get_or_create_shorteventid(starting_events)? + .multi_get_or_create_shorteventid(starting_events) + .await .iter() .enumerate() { let bucket: usize = short.try_into()?; - let bucket: usize = validated!(bucket % NUM_BUCKETS)?; + let bucket: usize = validated!(bucket % NUM_BUCKETS); buckets[bucket].insert((short, starting_events[i])); } @@ -85,7 +98,7 @@ impl Service { } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? { + if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); hits = hits.saturating_add(1); @@ -96,13 +109,13 @@ impl Service { let mut misses2: usize = 0; let mut chunk_cache = Vec::with_capacity(chunk.len()); for (sevent_id, event_id) in chunk { - if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? { + if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await { trace!(?event_id, "Found cache entry for event"); chunk_cache.extend(cached.iter().copied()); hits2 = hits2.saturating_add(1); } else { - let auth_chain = self.get_auth_chain_inner(room_id, event_id)?; - self.cache_auth_chain(vec![sevent_id], &auth_chain)?; + let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; + self.cache_auth_chain(vec![sevent_id], &auth_chain); chunk_cache.extend(auth_chain.iter()); misses2 = misses2.saturating_add(1); debug!( @@ -117,7 +130,7 @@ impl Service { chunk_cache.sort_unstable(); chunk_cache.dedup(); - self.cache_auth_chain_vec(chunk_key, &chunk_cache)?; + self.cache_auth_chain_vec(chunk_key, &chunk_cache); full_auth_chain.extend(chunk_cache.iter()); misses = misses.saturating_add(1); debug!( @@ -143,24 +156,29 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { trace!(?event_id, "processing auth event"); - match self.services.timeline.get_pdu(&event_id) { - Ok(Some(pdu)) => { + match self.services.timeline.get_pdu(&event_id).await { + Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"), + Ok(pdu) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( - "auth event {event_id:?} for incorrect room {} which is not {}", + "auth event {event_id:?} for incorrect room {} which is not {room_id}", pdu.room_id, - room_id ))); } + for auth_event in &pdu.auth_events { - let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?; + let sauthevent = self + .services + .short + .get_or_create_shorteventid(auth_event) + .await; if found.insert(sauthevent) { trace!(?event_id, ?auth_event, "adding auth event to processing queue"); @@ -168,32 +186,27 @@ impl Service { } } }, - Ok(None) => { - warn!(?event_id, "Could not find pdu mentioned in auth events"); - }, - Err(error) => { - error!(?event_id, ?error, "Could not load event in auth chain"); - }, } } Ok(found) } - pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { - self.db.get_cached_eventid_authchain(key) + #[inline] + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { + let val = auth_chain.iter().copied().collect::>(); + self.db.cache_auth_chain(key, val); } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) { + let val = auth_chain.iter().copied().collect::>(); + self.db.cache_auth_chain(key, val); } pub fn get_cache_usage(&self) -> (usize, usize) { diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs deleted file mode 100644 index 713ee0576..000000000 --- a/src/service/rooms/directory/data.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{OwnedRoomId, RoomId}; - -pub(super) struct Data { - publicroomids: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - publicroomids: db["publicroomids"].clone(), - } - } - - pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.insert(room_id.as_bytes(), &[]) - } - - pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.remove(room_id.as_bytes()) - } - - pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) - } - - pub(super) fn public_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - })) - } -} diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 706e6c2e5..5666a91a7 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,36 +1,44 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, RoomId}; - -use self::data::Data; +use conduit::{implement, utils::stream::TryIgnore, Result}; +use database::{Ignore, Map}; +use futures::{Stream, StreamExt}; +use ruma::RoomId; pub struct Service { db: Data, } +struct Data { + publicroomids: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + publicroomids: args.db["publicroomids"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } +#[implement(Service)] +pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id.as_bytes(), &[]); } - #[tracing::instrument(skip(self), level = "debug")] - pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } +#[implement(Service)] +pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); } - #[tracing::instrument(skip(self), level = "debug")] - pub fn is_public_room(&self, room_id: &RoomId) -> Result { self.db.is_public_room(room_id) } +#[implement(Service)] +pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.get(room_id).await.is_ok() } - #[tracing::instrument(skip(self), level = "debug")] - pub fn public_rooms(&self) -> impl Iterator> + '_ { self.db.public_rooms() } +#[implement(Service)] +pub fn public_rooms(&self) -> impl Stream + Send { + self.db + .publicroomids + .keys() + .ignore_err() + .map(|(room_id, _): (&RoomId, Ignore)| room_id) } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index bee986deb..4708a86cb 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,19 +1,21 @@ mod parse_incoming_pdu; use std::{ + borrow::Borrow, collections::{hash_map, BTreeMap, HashMap, HashSet}, fmt::Write, - pin::Pin, sync::{Arc, RwLock as StdRwLock}, time::Instant, }; use conduit::{ - debug, debug_error, debug_info, err, error, info, pdu, trace, - utils::{math::continue_exponential_backoff_secs, MutexMap}, - warn, Error, PduEvent, Result, + debug, debug_error, debug_info, debug_warn, err, info, pdu, + result::LogErr, + trace, + utils::{math::continue_exponential_backoff_secs, IterStream, MutexMap}, + warn, Err, Error, PduEvent, Result, }; -use futures_util::Future; +use futures::{future, future::ready, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::{ client::error::ErrorKind, @@ -27,7 +29,7 @@ use ruma::{ }, int, serde::Base64, - state_res::{self, RoomVersion, StateMap}, + state_res::{self, EventTypeExt, RoomVersion, StateMap}, uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, ServerName, }; @@ -60,14 +62,6 @@ struct Services { type RoomMutexMap = MutexMap; type HandleTimeMap = HashMap; -// We use some AsyncRecursiveType hacks here so we can call async funtion -// recursively. -type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; -type AsyncRecursiveCanonicalJsonVec<'a> = - AsyncRecursiveType<'a, Vec<(Arc, Option>)>>; -type AsyncRecursiveCanonicalJsonResult<'a> = - AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -142,17 +136,17 @@ impl Service { pub_key_map: &'a RwLock>>, ) -> Result>> { // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? { + if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { return Ok(Some(pdu_id.to_vec())); } // 1.1 Check the server is in the room - if !self.services.metadata.exists(room_id)? { + if !self.services.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } // 1.2 Check if the room is disabled - if self.services.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id).await { return Err(Error::BadRequest( ErrorKind::forbidden(), "Federation of this room is currently disabled on this server.", @@ -160,7 +154,7 @@ impl Service { } // 1.3.1 Check room ACL on origin field/server - self.acl_check(origin, room_id)?; + self.acl_check(origin, room_id).await?; // 1.3.2 Check room ACL on sender's server name let sender: OwnedUserId = serde_json::from_value( @@ -172,26 +166,23 @@ impl Service { ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid"))?; - self.acl_check(sender.server_name(), room_id)?; + self.acl_check(sender.server_name(), room_id).await?; // Fetch create event let create_event = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await?; // Procure the room version let room_version_id = Self::get_room_version_id(&create_event)?; - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; let (incoming_pdu, val) = self .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map) + .boxed() .await?; Self::check_room_id(room_id, &incoming_pdu)?; @@ -235,7 +226,7 @@ impl Service { { Ok(()) => continue, Err(e) => { - warn!("Prev event {} failed: {}", prev_id, e); + warn!("Prev event {prev_id} failed: {e}"); match self .services .globals @@ -287,7 +278,7 @@ impl Service { create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, ) -> Result<()> { // Check for disabled again because it might have changed - if self.services.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id).await { debug!( "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ event ID {event_id}" @@ -349,149 +340,153 @@ impl Service { } #[allow(clippy::too_many_arguments)] - fn handle_outlier_pdu<'a>( - &'a self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, + async fn handle_outlier_pdu<'a>( + &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, mut value: BTreeMap, auth_events_known: bool, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonResult<'a> { - Box::pin(async move { - // 1. Remove unsigned field - value.remove("unsigned"); + ) -> Result<(Arc, BTreeMap)> { + // 1. Remove unsigned field + value.remove("unsigned"); - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - // 2. Check signatures, otherwise drop - // 3. check content hash, redact if doesn't match - let room_version_id = Self::get_room_version_id(create_event)?; + // 2. Check signatures, otherwise drop + // 3. check content hash, redact if doesn't match + let room_version_id = Self::get_room_version_id(create_event)?; - let guard = pub_key_map.read().await; - let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e,); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed")); - }, - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - debug_info!("Calculated hash does not match (redaction): {event_id}"); - let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")); - }; + let guard = pub_key_map.read().await; + let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { + Err(e) => { + // Drop + warn!("Dropping bad event {event_id}: {e}"); + return Err!(Request(InvalidParam("Signature verification failed"))); + }, + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + debug_info!("Calculated hash does not match (redaction): {event_id}"); + let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { + return Err!(Request(InvalidParam("Redaction failed"))); + }; - // Skip the PDU if it is redacted and we already have it as an outlier event - if self.services.timeline.get_pdu_json(event_id)?.is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event was redacted and we already knew about it", - )); - } + // Skip the PDU if it is redacted and we already have it as an outlier event + if self.services.timeline.get_pdu_json(event_id).await.is_ok() { + return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); + } - obj - }, - Ok(ruma::signatures::Verified::All) => value, - }; + obj + }, + Ok(ruma::signatures::Verified::All) => value, + }; - drop(guard); + drop(guard); - // Now that we have checked the signature and hashes we can add the eventID and - // convert to our PduEvent type - val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + // Now that we have checked the signature and hashes we can add the eventID and + // convert to our PduEvent type + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = serde_json::from_value::( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - Self::check_room_id(room_id, &incoming_pdu)?; + Self::check_room_id(room_id, &incoming_pdu)?; - if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. - // These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of - // the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - debug!("Fetching auth events"); + if !auth_events_known { + // 4. fetch any missing auth events doing all checks listed here starting at 1. + // These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth events or some of + // the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!("Fetching auth events"); + Box::pin( self.fetch_and_handle_outliers( origin, &incoming_pdu .auth_events .iter() .map(|x| Arc::from(&**x)) - .collect::>(), + .collect::>>(), create_event, room_id, &room_version_id, pub_key_map, - ) - .await; - } + ), + ) + .await; + } - // 6. Reject "due to auth events" if the event doesn't pass auth based on the - // auth events - debug!("Checking based on auth events"); - // Build map of auth events - let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); - for id in &incoming_pdu.auth_events { - let Some(auth_event) = self.services.timeline.get_pdu(id)? else { - warn!("Could not find auth event {}", id); - continue; - }; + // 6. Reject "due to auth events" if the event doesn't pass auth based on the + // auth events + debug!("Checking based on auth events"); + // Build map of auth events + let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); + for id in &incoming_pdu.auth_events { + let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { + warn!("Could not find auth event {id}"); + continue; + }; - Self::check_room_id(room_id, &auth_event)?; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", - )); - }, - } + Self::check_room_id(room_id, &auth_event)?; + + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", + )); + }, } + } - // The original create event must be in the auth events - if !matches!( - auth_events - .get(&(StateEventType::RoomCreate, String::new())) - .map(AsRef::as_ref), - Some(_) | None - ) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Incoming event refers to wrong create event.", - )); - } + // The original create event must be in the auth events + if !matches!( + auth_events + .get(&(StateEventType::RoomCreate, String::new())) + .map(AsRef::as_ref), + Some(_) | None + ) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); + } - if !state_res::event_auth::auth_check( - &Self::to_room_version(&room_version_id), - &incoming_pdu, - None::, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed"))? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); - } + let state_fetch = |ty: &'static StateEventType, sk: &str| { + let key = ty.with_state_key(sk); + ready(auth_events.get(&key)) + }; - trace!("Validation successful."); + let auth_check = state_res::event_auth::auth_check( + &Self::to_room_version(&room_version_id), + &incoming_pdu, + None, // TODO: third party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - // 7. Persist the event as an outlier. - self.services - .outlier - .add_pdu_outlier(&incoming_pdu.event_id, &val)?; + if !auth_check { + return Err!(Request(Forbidden("Auth check failed"))); + } - trace!("Added pdu as outlier."); + trace!("Validation successful."); - Ok((Arc::new(incoming_pdu), val)) - }) + // 7. Persist the event as an outlier. + self.services + .outlier + .add_pdu_outlier(&incoming_pdu.event_id, &val); + + trace!("Added pdu as outlier."); + + Ok((Arc::new(incoming_pdu), val)) } pub async fn upgrade_outlier_to_timeline_pdu( @@ -499,16 +494,22 @@ impl Service { origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, ) -> Result>> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(pduid) = self + .services + .timeline + .get_pdu_id(&incoming_pdu.event_id) + .await + { return Ok(Some(pduid.to_vec())); } if self .services .pdu_metadata - .is_event_soft_failed(&incoming_pdu.event_id)? + .is_event_soft_failed(&incoming_pdu.event_id) + .await { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + return Err!(Request(InvalidParam("Event has been soft failed"))); } debug!("Upgrading to timeline pdu"); @@ -545,57 +546,69 @@ impl Service { debug!("Performing auth check"); // 11. Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( + let state_fetch_state = &state_at_incoming_event; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = state_fetch_state.get(&shortstatekey)?; + self.services.timeline.get_pdu(event_id).await.ok() + }; + + let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, - None::, // TODO: third party invite - |k, s| { - self.services - .short - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten()) - }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; - if !check_result { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Event has failed auth check with state at the event.", - )); + if !auth_check { + return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } debug!("Gathering auth events"); - let auth_events = self.services.state.get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - )?; + let auth_events = self + .services + .state + .get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + ) + .await?; + + let state_fetch = |k: &'static StateEventType, s: &str| { + let key = k.with_state_key(s); + ready(auth_events.get(&key).cloned()) + }; + + let auth_check = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None, // third-party invite + state_fetch, + ) + .await + .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res debug!("Performing soft-fail check"); let soft_fail = { use RoomVersionId::*; - !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::, |k, s| { - auth_events.get(&(k.clone(), s.to_owned())) - }) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? + !auth_check || incoming_pdu.kind == TimelineEventType::RoomRedaction && match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &incoming_pdu.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? } else { false } @@ -605,12 +618,11 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; if let Some(redact_id) = &content.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? + !self + .services + .state_accessor + .user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true) + .await? } else { false } @@ -627,28 +639,52 @@ impl Service { // Now we calculate the set of extremities this room has after the incoming // event has been applied. We start with the previous extremities (aka leaves) trace!("Calculating extremities"); - let mut extremities = self.services.state.get_forward_extremities(room_id)?; - trace!("Calculated {} extremities", extremities.len()); + let mut extremities: HashSet<_> = self + .services + .state + .get_forward_extremities(room_id) + .map(ToOwned::to_owned) + .collect() + .await; // Remove any forward extremities that are referenced by this incoming event's // prev_events + trace!( + "Calculated {} extremities; checking against {} prev_events", + extremities.len(), + incoming_pdu.prev_events.len() + ); for prev_event in &incoming_pdu.prev_events { - extremities.remove(prev_event); + extremities.remove(&(**prev_event)); } // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); + let mut retained = HashSet::new(); + for id in &extremities { + if !self + .services + .pdu_metadata + .is_event_referenced(room_id, id) + .await + { + retained.insert(id.clone()); + } + } + + extremities.retain(|id| retained.contains(id)); debug!("Retained {} extremities. Compressing state", extremities.len()); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - self.services - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::>()?, - ); + + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + self.services + .state_compressor + .compress_state_event(*shortstatekey, id) + .await, + ); + } + + let state_ids_compressed = Arc::new(state_ids_compressed); if incoming_pdu.state_key.is_some() { debug!("Event is a state-event. Deriving new room state"); @@ -659,9 +695,11 @@ impl Service { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await; - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + let event_id = &incoming_pdu.event_id; + state_after.insert(shortstatekey, event_id.clone()); } let new_room_state = self @@ -673,7 +711,8 @@ impl Service { let (sstatehash, new, removed) = self .services .state_compressor - .save_state(room_id, new_room_state)?; + .save_state(room_id, new_room_state) + .await?; self.services .state @@ -698,16 +737,16 @@ impl Service { .await?; // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); + warn!("Event was soft failed: {incoming_pdu:?}"); self.services .pdu_metadata - .mark_event_soft_failed(&incoming_pdu.event_id)?; + .mark_event_soft_failed(&incoming_pdu.event_id); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } trace!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); + extremities.insert(incoming_pdu.event_id.clone().into()); // Now that the event has passed all auth it is added into the timeline. // We use the `state_at_event` instead of `state_after` so we accurately @@ -718,7 +757,7 @@ impl Service { .append_incoming_pdu( &incoming_pdu, val, - extremities.iter().map(|e| (**e).to_owned()).collect(), + extremities.into_iter().collect(), state_ids_compressed, soft_fail, &state_lock, @@ -735,6 +774,7 @@ impl Service { Ok(pdu_id) } + #[tracing::instrument(skip_all, name = "resolve")] pub async fn resolve_state( &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, ) -> Result>> { @@ -742,8 +782,9 @@ impl Service { let current_sstatehash = self .services .state - .get_room_shortstatehash(room_id)? - .expect("every room has state"); + .get_room_shortstatehash(room_id) + .await + .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; let current_state_ids = self .services @@ -752,70 +793,69 @@ impl Service { .await?; let fork_states = [current_state_ids, incoming_state]; - let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) - .await? - .collect(), - ); + let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); + + let auth_chain: HashSet> = self + .services + .auth_chain + .get_event_ids(room_id, &starting_events) + .await? + .into_iter() + .collect(); + + auth_chain_sets.push(auth_chain); } debug!("Loading fork states"); - let fork_states: Vec<_> = fork_states + let fork_states: Vec>> = fork_states .into_iter() - .map(|map| { - map.into_iter() + .stream() + .then(|fork_state| { + fork_state + .into_iter() + .stream() .filter_map(|(k, id)| { self.services .short .get_statekey_from_short(k) - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .ok() + .map_ok_or_else(|_| None, move |(ty, st_key)| Some(((ty, st_key), id))) }) - .collect::>() + .collect() }) - .collect(); - - let lock = self.services.globals.stateres_mutex.lock(); + .collect() + .boxed() + .await; debug!("Resolving state"); - let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + let lock = self.services.globals.stateres_mutex.lock(); - let state = match state_resolve { - Ok(new_state) => new_state, - Err(e) => { - error!("State resolution failed: {}", e); - return Err(Error::bad_database( - "State resolution failed, either an event could not be found or deserialization", - )); - }, - }; + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?; drop(lock); debug!("State resolution done. Compressing state"); - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - self.services - .state_compressor - .compress_state_event(shortstatekey, &event_id) - }) - .collect::>()?; + let mut new_room_state = HashSet::new(); + for ((event_type, state_key), event_id) in state { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) + .await; + + let compressed = self + .services + .state_compressor + .compress_state_event(shortstatekey, &event_id) + .await; + + new_room_state.insert(compressed); + } Ok(Arc::new(new_room_state)) } @@ -827,46 +867,47 @@ impl Service { &self, incoming_pdu: &Arc, ) -> Result>>> { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = self + let Ok(prev_event_sstatehash) = self .services .state_accessor - .pdu_shortstatehash(prev_event)?; - - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some( - self.services - .state_accessor - .state_full_ids(shortstatehash) - .await, - ) - } else { - None + .pdu_shortstatehash(prev_event) + .await + else { + return Ok(None); }; - if let Some(Ok(mut state)) = state { - debug!("Using cached state"); - let prev_pdu = self - .services - .timeline - .get_pdu(prev_event) - .ok() - .flatten() - .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; + let Ok(mut state) = self + .services + .state_accessor + .state_full_ids(prev_event_sstatehash) + .await + .log_err() + else { + return Ok(None); + }; - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; + debug!("Using cached state"); + let prev_pdu = self + .services + .timeline + .get_pdu(prev_event) + .await + .map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?; - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } + if let Some(state_key) = &prev_pdu.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .await; - return Ok(Some(state)); + state.insert(shortstatekey, Arc::from(prev_event)); + // Now it's the state after the pdu } - Ok(None) + debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result"); + + Ok(Some(state)) } #[tracing::instrument(skip_all, name = "state")] @@ -878,15 +919,16 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else { + let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { okay = false; break; }; - let Ok(Some(sstatehash)) = self + let Ok(sstatehash) = self .services .state_accessor .pdu_shortstatehash(prev_eventid) + .await else { okay = false; break; @@ -901,79 +943,85 @@ impl Service { let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = self + let Ok(mut leaf_state) = self .services .state_accessor .state_full_ids(sstatehash) - .await?; + .await + else { + continue; + }; if let Some(state_key) = &prev_event.state_key { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) + .await; + + let event_id = &prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); // Now it's the state after the pdu } let mut state = StateMap::with_capacity(leaf_state.len()); let mut starting_events = Vec::with_capacity(leaf_state.len()); - - for (k, id) in leaf_state { - if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) { + for (k, id) in &leaf_state { + if let Ok((ty, st_key)) = self + .services + .short + .get_statekey_from_short(*k) + .await + .log_err() + { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); } - starting_events.push(id); + + starting_events.push(id.borrow()); } - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, starting_events) - .await? - .collect(), - ); + let auth_chain: HashSet> = self + .services + .auth_chain + .get_event_ids(room_id, &starting_events) + .await? + .into_iter() + .collect(); + auth_chain_sets.push(auth_chain); fork_states.push(state); } let lock = self.services.globals.stateres_mutex.lock(); - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed.")))); + drop(lock); - Ok(match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!( - "State resolution on prev events failed, either an event could not be found or deserialization: {}", - e - ); - None - }, - }) + let Ok(new_state) = result else { + return Ok(None); + }; + + new_state + .iter() + .stream() + .then(|((event_type, state_key), event_id)| { + self.services + .short + .get_or_create_shortstatekey(event_type, state_key) + .map(move |shortstatekey| (shortstatekey, event_id.clone())) + }) + .collect() + .map(Some) + .map(Ok) + .await } /// Call /state_ids to find out what the state at this pdu is. We trust the @@ -985,7 +1033,7 @@ impl Service { pub_key_map: &RwLock>>, event_id: &EventId, ) -> Result>>> { debug!("Fetching state ids"); - match self + let res = self .services .sending .send_federation_request( @@ -996,61 +1044,57 @@ impl Service { }, ) .await - { - Ok(res) => { - debug!("Fetching state events"); - let collect = res - .pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(); - - let state_vec = self - .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) - .await; - - let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; + .inspect_err(|e| warn!("Fetching state for event failed: {e}"))?; + + debug!("Fetching state events"); + let collect = res + .pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(); + + let state_vec = self + .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) + .boxed() + .await; - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; + let mut state: HashMap<_, Arc> = HashMap::with_capacity(state_vec.len()); + for (pdu, _) in state_vec { + let state_key = pdu + .state_key + .clone() + .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - match state.entry(shortstatekey) { - hash_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::bad_database( - "State event's type and state_key combination exists multiple times.", - )) - }, - } - } + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) + .await; - // The original create event must still be in the state - let create_shortstatekey = self - .services - .short - .get_shortstatekey(&StateEventType::RoomCreate, "")? - .expect("Room exists"); + match state.entry(shortstatekey) { + hash_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::bad_database( + "State event's type and state_key combination exists multiple times.", + )) + }, + } + } - if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { - return Err(Error::bad_database("Incoming event refers to wrong create event.")); - } + // The original create event must still be in the state + let create_shortstatekey = self + .services + .short + .get_shortstatekey(&StateEventType::RoomCreate, "") + .await?; - Ok(Some(state)) - }, - Err(e) => { - warn!("Fetching state for event failed: {}", e); - Err(e) - }, + if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { + return Err!(Database("Incoming event refers to wrong create event.")); } + + Ok(Some(state)) } /// Find the event and auth it. Once the event is validated (steps 1 - 8) @@ -1062,191 +1106,196 @@ impl Service { /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? - pub fn fetch_and_handle_outliers<'a>( - &'a self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, + pub async fn fetch_and_handle_outliers<'a>( + &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonVec<'a> { - Box::pin(async move { - let back_off = |id| async { - match self + ) -> Vec<(Arc, Option>)> { + let back_off = |id| match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + }; + + let mut events_with_auth_events = Vec::with_capacity(events.len()); + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { + trace!("Found {id} in db"); + events_with_auth_events.push((id, Some(local_pdu), vec![])); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. + let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); + let mut events_all = HashSet::with_capacity(todo_auth_events.len()); + let mut i: u64 = 0; + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = self .services .globals .bad_event_ratelimiter - .write() + .read() .expect("locked") - .entry(id) + .get(&*next_id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); + continue; + } } - }; - let mut events_with_auth_events = Vec::with_capacity(events.len()); - for id in events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) { - trace!("Found {} in db", id); - events_with_auth_events.push((id, Some(local_pdu), vec![])); + if events_all.contains(&next_id) { continue; } - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. - let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); - let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i: u64 = 0; - while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&*next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - info!("Backing off from {next_id}"); - continue; - } - } - - if events_all.contains(&next_id) { - continue; - } - - i = i.saturating_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } + i = i.saturating_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } - if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) { - trace!("Found {} in db", next_id); - continue; - } + if self.services.timeline.get_pdu(&next_id).await.is_ok() { + trace!("Found {next_id} in db"); + continue; + } - debug!("Fetching {} over federation.", next_id); - match self - .services - .sending - .send_federation_request( - origin, - get_event::v1::Request { - event_id: (*next_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - debug!("Got {} over federation", next_id); - let Ok((calculated_event_id, value)) = - pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) - else { - back_off((*next_id).to_owned()).await; - continue; - }; - - if calculated_event_id != *next_id { - warn!( - "Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", - next_id, calculated_event_id, &res.pdu - ); - } + debug!("Fetching {next_id} over federation."); + match self + .services + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + debug!("Got {next_id} over federation"); + let Ok((calculated_event_id, value)) = + pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) + else { + back_off((*next_id).to_owned()); + continue; + }; + + if calculated_event_id != *next_id { + warn!( + "Server didn't return event id we requested: requested: {next_id}, we got \ + {calculated_event_id}. Event: {:?}", + &res.pdu + ); + } - if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { - for auth_event in auth_events { - if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { - let a: Arc = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } + if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { + for auth_event in auth_events { + if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { + let a: Arc = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); } - } else { - warn!("Auth event list invalid"); } + } else { + warn!("Auth event list invalid"); + } - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - }, - Err(e) => { - debug_error!("Failed to fetch event {next_id}: {e}"); - back_off((*next_id).to_owned()).await; - }, - } + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + }, + Err(e) => { + debug_error!("Failed to fetch event {next_id}: {e}"); + back_off((*next_id).to_owned()); + }, } - events_with_auth_events.push((id, None, events_in_reverse_order)); } + events_with_auth_events.push((id, None, events_in_reverse_order)); + } - // We go through all the signatures we see on the PDUs and their unresolved - // dependencies and fetch the corresponding signing keys - self.services - .server_keys - .fetch_required_signing_keys( - events_with_auth_events - .iter() - .flat_map(|(_id, _local_pdu, events)| events) - .map(|(_event_id, event)| event), + // We go through all the signatures we see on the PDUs and their unresolved + // dependencies and fetch the corresponding signing keys + self.services + .server_keys + .fetch_required_signing_keys( + events_with_auth_events + .iter() + .flat_map(|(_id, _local_pdu, events)| events) + .map(|(_event_id, event)| event), + pub_key_map, + ) + .await + .unwrap_or_else(|e| { + warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}"); + }); + + let mut pdus = Vec::with_capacity(events_with_auth_events.len()); + for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Some(local_pdu) = local_pdu { + trace!("Found {id} in db"); + pdus.push((local_pdu.clone(), None)); + } + + for (next_id, value) in events_in_reverse_order.into_iter().rev() { + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(&*next_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!("Backing off from {next_id}"); + continue; + } + } + + match Box::pin(self.handle_outlier_pdu( + origin, + create_event, + &next_id, + room_id, + value.clone(), + true, pub_key_map, - ) + )) .await - .unwrap_or_else(|e| { - warn!("Could not fetch all signatures for PDUs from {}: {:?}", origin, e); - }); - - let mut pdus = Vec::with_capacity(events_with_auth_events.len()); - for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Some(local_pdu) = local_pdu { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - } - for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&**next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - debug!("Backing off from {next_id}"); - continue; + { + Ok((pdu, json)) => { + if next_id == *id { + pdus.push((pdu, Some(json))); } - } - - match self - .handle_outlier_pdu(origin, create_event, next_id, room_id, value.clone(), true, pub_key_map) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); - } - }, - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()).await; - }, - } + }, + Err(e) => { + warn!("Authentication of event {next_id} failed: {e:?}"); + back_off(next_id.into()); + }, } } - pdus - }) + } + pdus } #[allow(clippy::type_complexity)] @@ -1262,16 +1311,12 @@ impl Service { let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; let mut amount = 0; while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = self + if let Some((pdu, mut json_opt)) = self .fetch_and_handle_outliers( origin, &[prev_event_id.clone()], @@ -1280,28 +1325,29 @@ impl Service { room_version_id, pub_key_map, ) + .boxed() .await .pop() { Self::check_room_id(room_id, &pdu)?; - if amount > self.services.globals.max_fetch_prev_events() { - // Max limit reached - debug!( - "Max prev event limit reached! Limit: {}", - self.services.globals.max_fetch_prev_events() - ); + let limit = self.services.globals.max_fetch_prev_events(); + if amount > limit { + debug_warn!("Max prev event limit reached! Limit: {limit}"); graph.insert(prev_event_id.clone(), HashSet::new()); continue; } - if let Some(json) = json_opt.or_else(|| { - self.services + if json_opt.is_none() { + json_opt = self + .services .outlier .get_outlier_pdu_json(&prev_event_id) - .ok() - .flatten() - }) { + .await + .ok(); + } + + if let Some(json) = json_opt { if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { amount = amount.saturating_add(1); for prev_prev in &pdu.prev_events { @@ -1327,56 +1373,42 @@ impl Service { } } - let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { + let event_fetch = |event_id| { + let origin_server_ts = eventid_info + .get(&event_id) + .cloned() + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts); + // This return value is the key used for sorting events, // events are then sorted by power level, time, // and lexically by event_id. - Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|e| { - error!("Error sorting prev events: {e}"); - Error::bad_database("Error sorting prev events") - })?; + future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts))) + }; + + let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch) + .await + .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?; Ok((sorted, eventid_info)) } /// Returns Ok if the acl allows the server #[tracing::instrument(skip_all)] - pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { - let acl_event = if let Some(acl) = - self.services - .state_accessor - .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? - { - trace!("ACL event found: {acl:?}"); - acl - } else { - trace!("No ACL event found"); + pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + let Ok(acl_event_content) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomServerAcl, "") + .await + .map(|c: RoomServerAclEventContent| c) + .inspect(|acl| trace!("ACL content found: {acl:?}")) + .inspect_err(|e| trace!("No ACL content found: {e:?}")) + else { return Ok(()); }; - let acl_event_content: RoomServerAclEventContent = match serde_json::from_str(acl_event.content.get()) { - Ok(content) => { - trace!("Found ACL event contents: {content:?}"); - content - }, - Err(e) => { - warn!("Invalid ACL event: {e}"); - return Ok(()); - }, - }; - if acl_event_content.allow.is_empty() { warn!("Ignoring broken ACL event (allow key is empty)"); - // Ignore broken acl events return Ok(()); } @@ -1384,16 +1416,18 @@ impl Service { trace!("server {server_name} is allowed by ACL"); Ok(()) } else { - debug!("Server {} was denied by room ACL in {}", server_name, room_id); - Err(Error::BadRequest(ErrorKind::forbidden(), "Server was denied by room ACL")) + debug!("Server {server_name} was denied by room ACL in {room_id}"); + Err!(Request(Forbidden("Server was denied by room ACL"))) } } fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { if pdu.room_id != room_id { - warn!("Found event from room {} in room {}", pdu.room_id, room_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has wrong room id")); + return Err!(Request(InvalidParam( + warn!(pdu_event_id = ?pdu.event_id, pdu_room_id = ?pdu.room_id, ?room_id, "Found event from room in room") + ))); } + Ok(()) } @@ -1408,4 +1442,10 @@ impl Service { fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { RoomVersion::new(room_version_id).expect("room version is supported") } + + async fn event_exists(&self, event_id: Arc) -> bool { self.services.timeline.pdu_exists(&event_id).await } + + async fn event_fetch(&self, event_id: Arc) -> Option> { + self.services.timeline.get_pdu(&event_id).await.ok() + } } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index a7ffe1930..2de3e28ef 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -3,7 +3,9 @@ use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; impl super::Service { - pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + pub async fn parse_incoming_pdu( + &self, pdu: &RawJsonValue, + ) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { debug_warn!("Error parsing incoming event {pdu:#?}"); err!(BadServerResponse("Error parsing incoming event {e:?}")) @@ -14,7 +16,7 @@ impl super::Service { .and_then(|id| RoomId::parse(id.as_str()?).ok()) .ok_or(err!(Request(InvalidParam("Invalid room id in pdu"))))?; - let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else { + let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else { return Err!("Server is not in room {room_id}"); }; diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs deleted file mode 100644 index 073d45f56..000000000 --- a/src/service/rooms/lazy_loading/data.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, RoomId, UserId}; - -pub(super) struct Data { - lazyloadedids: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - lazyloadedids: db["lazyloadedids"].clone(), - } - } - - pub(super) fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) - } - - pub(super) fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for ll_id in confirmed_user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } - - Ok(()) - } - - pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } - - Ok(()) - } -} diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 0a9d4cf29..e0816d3f3 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,21 +1,26 @@ -mod data; - use std::{ collections::{HashMap, HashSet}, fmt::Write, sync::{Arc, Mutex}, }; -use conduit::{PduCount, Result}; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, Result, +}; +use database::{Interfix, Map}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use self::data::Data; - pub struct Service { - pub lazy_load_waiting: Mutex, + lazy_load_waiting: Mutex, db: Data, } +struct Data { + lazyloadedids: Arc, +} + type LazyLoadWaiting = HashMap; type LazyLoadWaitingKey = (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount); type LazyLoadWaitingVal = HashSet; @@ -23,8 +28,10 @@ type LazyLoadWaitingVal = HashSet; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - lazy_load_waiting: Mutex::new(HashMap::new()), - db: Data::new(args.db), + lazy_load_waiting: LazyLoadWaiting::new().into(), + db: Data { + lazyloadedids: args.db["lazyloadedids"].clone(), + }, })) } @@ -40,47 +47,60 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result { - self.db - .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +#[inline] +pub async fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, +) -> bool { + let key = (user_id, device_id, room_id, ll_user); + self.db.lazyloadedids.qry(&key).await.is_ok() +} - #[tracing::instrument(skip(self), level = "debug")] - pub async fn lazy_load_mark_sent( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, - count: PduCount, - ) { - self.lazy_load_waiting - .lock() - .expect("locked") - .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn lazy_load_mark_sent( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, count: PduCount, +) { + let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count); - #[tracing::instrument(skip(self), level = "debug")] - pub async fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, - ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - self.db - .lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; - } else { - // Ignore - } + self.lazy_load_waiting + .lock() + .expect("locked") + .insert(key, lazy_load); +} - Ok(()) - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) { + let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since); + + let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else { + return; + }; - #[tracing::instrument(skip(self), level = "debug")] - pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - self.db.lazy_load_reset(user_id, device_id, room_id) + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); + + for ll_id in &user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(ll_id.as_bytes()); + self.db.lazyloadedids.insert(&key, &[]); } } + +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) { + let prefix = (user_id, device_id, room_id, Interfix); + self.db + .lazyloadedids + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.lazyloadedids.remove(key)) + .await; +} diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs deleted file mode 100644 index efe681b1b..000000000 --- a/src/service/rooms/metadata/data.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use conduit::{error, utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, RoomId}; - -use crate::{rooms, Dep}; - -pub(super) struct Data { - disabledroomids: Arc, - bannedroomids: Arc, - roomid_shortroomid: Arc, - pduid_pdu: Arc, - services: Services, -} - -struct Services { - short: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - disabledroomids: db["disabledroomids"].clone(), - bannedroomids: db["bannedroomids"].clone(), - roomid_shortroomid: db["roomid_shortroomid"].clone(), - pduid_pdu: db["pduid_pdu"].clone(), - services: Services { - short: args.depend::("rooms::short"), - }, - } - } - - pub(super) fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match self.services.short.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; - - // Look for PDUs in that room. - Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } - - pub(super) fn iter_ids<'a>(&'a self) -> Box> + 'a> { - Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - })) - } - - #[inline] - pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } - - #[inline] - pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - if disabled { - self.disabledroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.disabledroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - #[inline] - pub(super) fn is_banned(&self, room_id: &RoomId) -> Result { - Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) - } - - #[inline] - pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { - if banned { - self.bannedroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.bannedroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - pub(super) fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.bannedroomids.iter().map( - |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|e| { - error!("Invalid room_id bytes in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids.") - })? - .try_into() - .map_err(|e| { - error!("Invalid room_id in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids") - })?; - - Ok(room_id) - }, - )) - } -} diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 7415c53b7..5d4a47c71 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,51 +1,92 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, RoomId}; +use conduit::{implement, utils::stream::TryIgnore, Result}; +use database::Map; +use futures::{Stream, StreamExt}; +use ruma::RoomId; -use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + disabledroomids: Arc, + bannedroomids: Arc, + roomid_shortroomid: Arc, + pduid_pdu: Arc, +} + +struct Services { + short: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + disabledroomids: args.db["disabledroomids"].clone(), + bannedroomids: args.db["bannedroomids"].clone(), + roomid_shortroomid: args.db["roomid_shortroomid"].clone(), + pduid_pdu: args.db["pduid_pdu"].clone(), + }, + services: Services { + short: args.depend::("rooms::short"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Checks if a room exists. - #[inline] - pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } +#[implement(Service)] +pub async fn exists(&self, room_id: &RoomId) -> bool { + let Ok(prefix) = self.services.short.get_shortroomid(room_id).await else { + return false; + }; + + // Look for PDUs in that room. + self.db + .pduid_pdu + .keys_raw_prefix(&prefix) + .ignore_err() + .next() + .await + .is_some() +} - #[must_use] - pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } +#[implement(Service)] +pub fn iter_ids(&self) -> impl Stream + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() } - #[inline] - pub fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } +#[implement(Service)] +#[inline] +pub fn disable_room(&self, room_id: &RoomId, disabled: bool) { + if disabled { + self.db.disabledroomids.insert(room_id.as_bytes(), &[]); + } else { + self.db.disabledroomids.remove(room_id.as_bytes()); + } +} - #[inline] - pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - self.db.disable_room(room_id, disabled) +#[implement(Service)] +#[inline] +pub fn ban_room(&self, room_id: &RoomId, banned: bool) { + if banned { + self.db.bannedroomids.insert(room_id.as_bytes(), &[]); + } else { + self.db.bannedroomids.remove(room_id.as_bytes()); } +} - #[inline] - pub fn is_banned(&self, room_id: &RoomId) -> Result { self.db.is_banned(room_id) } +#[implement(Service)] +pub fn list_banned_rooms(&self) -> impl Stream + Send + '_ { self.db.bannedroomids.keys().ignore_err() } - #[inline] - pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } +#[implement(Service)] +#[inline] +pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.qry(room_id).await.is_ok() } - #[inline] - #[must_use] - pub fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - self.db.list_banned_rooms() - } -} +#[implement(Service)] +#[inline] +pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.qry(room_id).await.is_ok() } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs deleted file mode 100644 index aa804721b..000000000 --- a/src/service/rooms/outlier/data.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::sync::Arc; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{CanonicalJsonObject, EventId}; - -use crate::PduEvent; - -pub(super) struct Data { - eventid_outlierpdu: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - eventid_outlierpdu: db["eventid_outlierpdu"].clone(), - } - } - - pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) - } -} diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 22bd2092a..b9d042638 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,9 +1,7 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Deserialized, Map}; use ruma::{CanonicalJsonObject, EventId}; use crate::PduEvent; @@ -12,31 +10,48 @@ pub struct Service { db: Data, } +struct Data { + eventid_outlierpdu: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Returns the pdu from the outlier tree. - pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_outlier_pdu_json(event_id) - } +/// Returns the pdu from the outlier tree. +#[implement(Service)] +pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result { + self.db + .eventid_outlierpdu + .get(event_id) + .await + .deserialized() +} - /// Returns the pdu from the outlier tree. - /// - /// TODO: use this? - #[allow(dead_code)] - pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu(event_id) } +/// Returns the pdu from the outlier tree. +#[implement(Service)] +pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result { + self.db + .eventid_outlierpdu + .get(event_id) + .await + .deserialized() +} - /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu), level = "debug")] - pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.db.add_pdu_outlier(event_id, pdu) - } +/// Append the PDU as an outlier. +#[implement(Service)] +#[tracing::instrument(skip(self, pdu), level = "debug")] +pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) { + self.db.eventid_outlierpdu.insert( + event_id.as_bytes(), + &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), + ); } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index d1649da81..f23234752 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,7 +1,13 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, PduCount, PduEvent, Result}; +use conduit::{ + result::LogErr, + utils, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, PduEvent, +}; use database::Map; +use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId, UserId}; use crate::{rooms, Dep}; @@ -17,8 +23,7 @@ struct Services { timeline: Dep, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box + 'a>; +pub(super) type PdusIterItem = (PduCount, PduEvent); impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { @@ -33,19 +38,17 @@ impl Data { } } - pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> { + pub(super) fn add_relation(&self, from: u64, to: u64) { let mut key = to.to_be_bytes().to_vec(); key.extend_from_slice(&from.to_be_bytes()); - self.tofrom_relation.insert(&key, &[])?; - Ok(()) + self.tofrom_relation.insert(&key, &[]); } pub(super) fn relations_until<'a>( &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, - ) -> Result> { + ) -> impl Stream + Send + 'a + '_ { let prefix = target.to_be_bytes().to_vec(); let mut current = prefix.clone(); - let count_raw = match until { PduCount::Normal(x) => x.saturating_sub(1), PduCount::Backfilled(x) => { @@ -55,53 +58,42 @@ impl Data { }; current.extend_from_slice(&count_raw.to_be_bytes()); - Ok(Box::new( - self.tofrom_relation - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(tofrom, _data)| { - let from = utils::u64_from_bytes(&tofrom[(size_of::())..]) - .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; - - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((PduCount::Normal(from), pdu)) - }), - )) + self.tofrom_relation + .rev_raw_keys_from(¤t) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|to_from| utils::u64_from_u8(&to_from[(size_of::())..])) + .filter_map(move |from| async move { + let mut pduid = shortroomid.to_be_bytes().to_vec(); + pduid.extend_from_slice(&from.to_be_bytes()); + let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?; + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((PduCount::Normal(from), pdu)) + }) } - pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; + self.referencedevents.insert(&key, &[]); } - - Ok(()) } - pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) + pub(super) async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + let key = (room_id, event_id); + self.referencedevents.qry(&key).await.is_ok() } - pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) + pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { + self.softfailedeventids.insert(event_id.as_bytes(), &[]); } - pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) + pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.softfailedeventids.qry(event_id).await.is_ok() } } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index d9eaf3244..dbaebfbf3 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,8 +1,8 @@ mod data; - use std::sync::Arc; -use conduit::{PduCount, PduEvent, Result}; +use conduit::{utils::stream::IterStream, PduCount, Result}; +use futures::StreamExt; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -10,7 +10,7 @@ use ruma::{ }; use serde::Deserialize; -use self::data::Data; +use self::data::{Data, PdusIterItem}; use crate::{rooms, Dep}; pub struct Service { @@ -51,21 +51,19 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self, from, to), level = "debug")] - pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + pub fn add_relation(&self, from: PduCount, to: PduCount) { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), _ => { // TODO: Relations with backfilled pdus - - Ok(()) }, } } #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( - &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option, - filter_rel_type: &Option, from: &Option, to: &Option, limit: &Option, + pub async fn paginate_relations_with_filter( + &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, + filter_rel_type: Option, from: Option<&String>, to: Option<&String>, limit: Option, recurse: bool, dir: Direction, ) -> Result { let from = match from { @@ -76,7 +74,7 @@ impl Service { }, }; - let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + let to = to.and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 let limit = limit @@ -92,30 +90,32 @@ impl Service { 1 }; - let relations_until = &self.relations_until(sender_user, room_id, target, from, depth)?; - let events: Vec<_> = relations_until // TODO: should be relations_after - .iter() - .filter(|(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::(pdu.content.get()) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &content.relates_to.rel_type == r) - } else { - false - } - }) - .take(limit) - .filter(|(_, pdu)| { - self.services - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to` - .collect(); + let relations_until: Vec = self + .relations_until(sender_user, room_id, target, from, depth) + .await?; + + // TODO: should be relations_after + let events: Vec<_> = relations_until + .into_iter() + .filter(move |(_, pdu): &PdusIterItem| { + if !filter_event_type.as_ref().map_or(true, |t| pdu.kind == *t) { + return false; + } + + let Ok(content) = serde_json::from_str::(pdu.content.get()) else { + return false; + }; + + filter_rel_type + .as_ref() + .map_or(true, |r| *r == content.relates_to.rel_type) + }) + .take(limit) + .take_while(|(k, _)| Some(*k) != to) + .stream() + .filter_map(|item| self.visibility_filter(sender_user, item)) + .collect() + .await; let next_token = events.last().map(|(count, _)| count).copied(); @@ -125,9 +125,9 @@ impl Service { .map(|(_, pdu)| pdu.to_message_like_event()) .collect(), Direction::Backward => events - .into_iter() - .rev() // relations are always most recent first - .map(|(_, pdu)| pdu.to_message_like_event()) + .into_iter() + .rev() // relations are always most recent first + .map(|(_, pdu)| pdu.to_message_like_event()) .collect(), }; @@ -135,68 +135,85 @@ impl Service { chunk: events_chunk, next_batch: next_token.map(|t| t.stringify()), prev_batch: Some(from.stringify()), - recursion_depth: if recurse { - Some(depth.into()) - } else { - None - }, + recursion_depth: recurse.then_some(depth.into()), }) } - pub fn relations_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, - ) -> Result> { - let room_id = self.services.short.get_or_create_shortroomid(room_id)?; - #[allow(unknown_lints)] - #[allow(clippy::manual_unwrap_or_default)] - let target = match self.services.timeline.get_pdu_count(target)? { - Some(PduCount::Normal(c)) => c, + async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { + let (_, pdu) = &item; + + self.services + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) + } + + pub async fn relations_until( + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, max_depth: u8, + ) -> Result> { + let room_id = self.services.short.get_or_create_shortroomid(room_id).await; + + let target = match self.services.timeline.get_pdu_count(target).await { + Ok(PduCount::Normal(c)) => c, // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator }; - self.db + let mut pdus: Vec = self + .db .relations_until(user_id, room_id, target, until) - .map(|mut relations| { - let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect(); - let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect(); - - while let Some(stack_pdu) = stack.pop() { - let target = match stack_pdu.0 .0 { - PduCount::Normal(c) => c, - // TODO: Support backfilled relations - PduCount::Backfilled(_) => 0, // This will result in an empty iterator - }; - - if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { - for relation in relations.flatten() { - if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); - } - - pdus.push(relation); - } - } + .collect() + .await; + + let mut stack: Vec<_> = pdus.clone().into_iter().map(|pdu| (pdu, 1)).collect(); + + while let Some(stack_pdu) = stack.pop() { + let target = match stack_pdu.0 .0 { + PduCount::Normal(c) => c, + // TODO: Support backfilled relations + PduCount::Backfilled(_) => 0, // This will result in an empty iterator + }; + + let relations: Vec = self + .db + .relations_until(user_id, room_id, target, until) + .collect() + .await; + + for relation in relations { + if stack_pdu.1 < max_depth { + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } - pdus.sort_by(|a, b| a.0.cmp(&b.0)); - pdus - }) + pdus.push(relation); + } + } + + pdus.sort_by(|a, b| a.0.cmp(&b.0)); + + Ok(pdus) } + #[inline] #[tracing::instrument(skip_all, level = "debug")] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - self.db.mark_as_referenced(room_id, event_ids) + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { + self.db.mark_as_referenced(room_id, event_ids); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - self.db.is_event_referenced(room_id, event_id) + pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + self.db.is_event_referenced(room_id, event_id).await } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } + pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.db.is_event_soft_failed(event_id) } + pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.db.is_event_soft_failed(event_id).await + } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 0c156df38..a2c0fabca 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,10 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, RoomId, UserId}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::{ + events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, + serde::Raw, + CanonicalJsonObject, OwnedUserId, RoomId, UserId, +}; -use super::AnySyncEphemeralRoomEventIter; use crate::{globals, Dep}; pub(super) struct Data { @@ -18,6 +26,8 @@ struct Services { globals: Dep, } +pub(super) type ReceiptItem = (OwnedUserId, u64, Raw); + impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = &args.db; @@ -31,7 +41,9 @@ impl Data { } } - pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { + pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + type KeyVal<'a> = (&'a RoomId, u64, &'a UserId); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -39,108 +51,90 @@ impl Data { last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } + self.readreceiptid_readreceipt + .rev_keys_from_raw(&last_possible_key) + .ignore_err() + .ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id) + .ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u))) + .ready_for_each(|old: KeyVal<'_>| { + // This is the old room_latest + self.readreceiptid_readreceipt.del(&old); + }) + .await; let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); room_latest_id.push(0xFF); room_latest_id.extend_from_slice(user_id.as_bytes()); self.readreceiptid_readreceipt.insert( &room_latest_id, &serde_json::to_vec(event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) + ); } - pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> { + pub(super) fn readreceipts_since<'a>( + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream + Send + 'a { + let after_since = since.saturating_add(1); // +1 so we don't send the event at since + let first_possible_edu = (room_id, after_since); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); let prefix2 = prefix.clone(); - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since - - Box::new( - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count_offset = prefix.len().saturating_add(size_of::()); - let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id_offset = count_offset.saturating_add(1); - let user_id = UserId::parse( - utils::string_from_bytes(&k[user_id_offset..]) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, - ) + self.readreceiptid_readreceipt + .stream_raw_from(&first_possible_edu) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count_offset = prefix.len().saturating_add(size_of::()); + let user_id_offset = count_offset.saturating_add(1); + + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + + let user_id_str = utils::string_from_bytes(&k[user_id_offset..]) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?; + + let user_id = UserId::parse(user_id_str) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - let mut json = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - json.remove("room_id"); - - Ok(( - user_id, - count, - Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")), - )) - }), - ) - } + let mut json = serde_json::from_slice::(v) + .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); + json.remove("room_id"); - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; + let event = Raw::from_json(serde_json::value::to_raw_value(&json)?); - self.roomuserid_lastprivatereadupdate - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()) + Ok((user_id, count, event)) + }) + .ignore_err() } - pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some( - utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, - )) - }) + .insert(&key, &count.to_be_bytes()); + + self.roomuserid_lastprivatereadupdate + .insert(&key, &self.services.globals.next_count().unwrap().to_be_bytes()); } - pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); + pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.roomuserid_privateread.qry(&key).await.deserialized() + } - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) + pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.roomuserid_lastprivatereadupdate + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } } diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index da11e2a0f..ec34361e0 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -3,16 +3,17 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; use conduit::{debug, Result}; -use data::Data; +use futures::Stream; use ruma::{ events::{ receipt::{ReceiptEvent, ReceiptEventContent}, - AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent, + SyncEphemeralRoomEvent, }, serde::Raw, - OwnedUserId, RoomId, UserId, + RoomId, UserId, }; +use self::data::{Data, ReceiptItem}; use crate::{sending, Dep}; pub struct Service { @@ -24,9 +25,6 @@ struct Services { sending: Dep, } -type AnySyncEphemeralRoomEventIter<'a> = - Box)>> + 'a>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -42,44 +40,53 @@ impl crate::Service for Service { impl Service { /// Replaces the previous read receipt. - pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event)?; - self.services.sending.flush_room(room_id)?; - - Ok(()) + pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + self.db.readreceipt_update(user_id, room_id, event).await; + self.services + .sending + .flush_room(room_id) + .await + .expect("room flush failed"); } /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. + #[inline] #[tracing::instrument(skip(self), level = "debug")] pub fn readreceipts_since<'a>( - &'a self, room_id: &RoomId, since: u64, - ) -> impl Iterator)>> + 'a { + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream + Send + 'a { self.db.readreceipts_since(room_id, since) } /// Sets a private read marker at `count`. + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - self.db.private_read_set(room_id, user_id, count) + pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { + self.db.private_read_set(room_id, user_id, count); } /// Returns the private read marker. + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.private_read_get(room_id, user_id) + pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { + self.db.private_read_get(room_id, user_id).await } /// Returns the count of the last typing update in this room. - pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_privateread_update(user_id, room_id) + #[inline] + pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.last_privateread_update(user_id, room_id).await } } #[must_use] -pub fn pack_receipts(receipts: AnySyncEphemeralRoomEventIter<'_>) -> Raw> { +pub fn pack_receipts(receipts: I) -> Raw> +where + I: Iterator, +{ let mut json = BTreeMap::new(); - for (_user, _count, value) in receipts.flatten() { + for (_, _, value) in receipts { let receipt = serde_json::from_str::>(value.json().get()); if let Ok(value) = receipt { for (event, receipt) in value.content { diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index a0086095b..de98beeeb 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,13 +1,12 @@ use std::sync::Arc; -use conduit::{utils, Result}; +use conduit::utils::{set, stream::TryIgnore, IterStream, ReadyExt}; use database::Map; +use futures::StreamExt; use ruma::RoomId; use crate::{rooms, Dep}; -type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; - pub(super) struct Data { tokenids: Arc, services: Services, @@ -28,7 +27,7 @@ impl Data { } } - pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { let batch = tokenize(message_body) .map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); @@ -39,11 +38,10 @@ impl Data { }) .collect::>(); - self.tokenids - .insert_batch(batch.iter().map(database::KeyVal::from)) + self.tokenids.insert_batch(batch.iter()); } - pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { let batch = tokenize(message_body).map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); @@ -53,46 +51,53 @@ impl Data { }); for token in batch { - self.tokenids.remove(&token)?; + self.tokenids.remove(&token); } - - Ok(()) } - pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { + pub(super) async fn search_pdus( + &self, room_id: &RoomId, search_string: &str, + ) -> Option<(Vec>, Vec)> { let prefix = self .services .short - .get_shortroomid(room_id)? - .expect("room exists") + .get_shortroomid(room_id) + .await + .ok()? .to_be_bytes() .to_vec(); let words: Vec<_> = tokenize(search_string).collect(); - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.tokenids - .iter_from(&last_possible_id, true) // Newest pdus first - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(key, _)| key[prefix3.len()..].to_vec()) - }); - - let Some(common_elements) = utils::common_elements(iterators, |a, b| { - // We compare b with a because we reversed the iterator earlier - b.cmp(a) - }) else { - return Ok(None); - }; - - Ok(Some((Box::new(common_elements), words))) + let bufs: Vec<_> = words + .clone() + .into_iter() + .stream() + .then(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xFF); + let prefix3 = prefix2.clone(); + + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.tokenids + .rev_raw_keys_from(&last_possible_id) // Newest pdus first + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix2)) + .map(move |key| key[prefix3.len()..].to_vec()) + .collect::>() + }) + .collect() + .await; + + Some(( + set::intersection(bufs.iter().map(|buf| buf.iter())) + .cloned() + .collect(), + words, + )) } } @@ -100,7 +105,7 @@ impl Data { /// /// This may be used to tokenize both message bodies (for indexing) or search /// queries (for querying). -fn tokenize(body: &str) -> impl Iterator + '_ { +fn tokenize(body: &str) -> impl Iterator + Send + '_ { body.split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) .filter(|word| word.len() <= 50) diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 8caa0ce35..80b588044 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -21,20 +21,21 @@ impl crate::Service for Service { } impl Service { + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.index_pdu(shortroomid, pdu_id, message_body) + pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { + self.db.index_pdu(shortroomid, pdu_id, message_body); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.deindex_pdu(shortroomid, pdu_id, message_body) + pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { + self.db.deindex_pdu(shortroomid, pdu_id, message_body); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn search_pdus<'a>( - &'a self, room_id: &RoomId, search_string: &str, - ) -> Result> + 'a, Vec)>> { - self.db.search_pdus(room_id, search_string) + pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec>, Vec)> { + self.db.search_pdus(room_id, search_string).await } } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs deleted file mode 100644 index 17fbb64e8..000000000 --- a/src/service/rooms/short/data.rs +++ /dev/null @@ -1,195 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, warn, Error, Result}; -use database::Map; -use ruma::{events::StateEventType, EventId, RoomId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - eventid_shorteventid: Arc, - shorteventid_eventid: Arc, - statekey_shortstatekey: Arc, - shortstatekey_statekey: Arc, - roomid_shortroomid: Arc, - statehash_shortstatehash: Arc, - services: Services, -} - -struct Services { - globals: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - eventid_shorteventid: db["eventid_shorteventid"].clone(), - shorteventid_eventid: db["shorteventid_eventid"].clone(), - statekey_shortstatekey: db["statekey_shortstatekey"].clone(), - shortstatekey_statekey: db["shortstatekey_statekey"].clone(), - roomid_shortroomid: db["roomid_shortroomid"].clone(), - statehash_shortstatehash: db["statehash_shortstatehash"].clone(), - services: Services { - globals: args.depend::("globals"), - }, - } - } - - pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { - utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? - } else { - let shorteventid = self.services.globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - }; - - Ok(short) - } - - pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { - let mut ret: Vec = Vec::with_capacity(event_ids.len()); - let keys = event_ids - .iter() - .map(|id| id.as_bytes()) - .collect::>(); - for (i, short) in self - .eventid_shorteventid - .multi_get(&keys)? - .iter() - .enumerate() - { - #[allow(clippy::single_match_else)] - match short { - Some(short) => ret.push( - utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, - ), - None => { - let short = self.services.globals.next_count()?; - self.eventid_shorteventid - .insert(keys[i], &short.to_be_bytes())?; - self.shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i])?; - - debug_assert!(ret.len() == i, "position of result must match input"); - ret.push(short); - }, - } - } - - Ok(ret) - } - - pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = self - .statekey_shortstatekey - .get(&statekey_vec)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; - - Ok(short) - } - - pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? - } else { - let shortstatekey = self.services.globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; - shortstatekey - }; - - Ok(short) - } - - pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - - let event_id = EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - - Ok(event_id) - } - - pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xFF); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { - warn!("Event type in shortstatekey_statekey is invalid: {}", e); - Error::bad_database("Event type in shortstatekey_statekey is invalid.") - })?); - - let state_key = utils::string_from_bytes(statekey_bytes) - .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; - - let result = (event_type, state_key); - - Ok(result) - } - - /// Returns (shortstatehash, already_existed) - pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { - ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ) - } else { - let shortstatehash = self.services.globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - }) - } - - pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) - .transpose() - } - - pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { - utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? - } else { - let short = self.services.globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - }) - } -} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index bfe0e9a0e..20082da23 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,59 +1,217 @@ -mod data; - use std::sync::Arc; -use conduit::Result; +use conduit::{err, implement, utils, Result}; +use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; -use self::data::Data; +use crate::{globals, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + eventid_shorteventid: Arc, + shorteventid_eventid: Arc, + statekey_shortstatekey: Arc, + shortstatekey_statekey: Arc, + roomid_shortroomid: Arc, + statehash_shortstatehash: Arc, +} + +struct Services { + globals: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + eventid_shorteventid: args.db["eventid_shorteventid"].clone(), + shorteventid_eventid: args.db["shorteventid_eventid"].clone(), + statekey_shortstatekey: args.db["statekey_shortstatekey"].clone(), + shortstatekey_statekey: args.db["shortstatekey_statekey"].clone(), + roomid_shortroomid: args.db["roomid_shortroomid"].clone(), + statehash_shortstatehash: args.db["statehash_shortstatehash"].clone(), + }, + services: Services { + globals: args.depend::("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - self.db.get_or_create_shorteventid(event_id) +#[implement(Service)] +pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 { + if let Ok(shorteventid) = self + .db + .eventid_shorteventid + .get(event_id) + .await + .deserialized() + { + return shorteventid; } - pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { - self.db.multi_get_or_create_shorteventid(event_ids) - } + let shorteventid = self.services.globals.next_count().unwrap(); + self.db + .eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes()); + self.db + .shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes()); - pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - self.db.get_shortstatekey(event_type, state_key) - } + shorteventid +} - pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - self.db.get_or_create_shortstatekey(event_type, state_key) - } +#[implement(Service)] +pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { + self.db + .eventid_shorteventid + .get_batch_blocking(event_ids.iter()) + .into_iter() + .enumerate() + .map(|(i, result)| match result { + Ok(ref short) => utils::u64_from_u8(short), + Err(_) => { + let short = self.services.globals.next_count().unwrap(); + self.db + .eventid_shorteventid + .insert(event_ids[i], &short.to_be_bytes()); + self.db + .shorteventid_eventid + .insert(&short.to_be_bytes(), event_ids[i]); - pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - self.db.get_eventid_from_short(shorteventid) - } + short + }, + }) + .collect() +} - pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.db.get_statekey_from_short(shortstatekey) - } +#[implement(Service)] +pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + let key = (event_type, state_key); + self.db + .statekey_shortstatekey + .qry(&key) + .await + .deserialized() +} - /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) +#[implement(Service)] +pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 { + let key = (event_type.to_string(), state_key); + if let Ok(shortstatekey) = self + .db + .statekey_shortstatekey + .qry(&key) + .await + .deserialized() + { + return shortstatekey; } - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } + let mut key = event_type.to_string().as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(state_key.as_bytes()); + + let shortstatekey = self.services.globals.next_count().unwrap(); + self.db + .statekey_shortstatekey + .insert(&key, &shortstatekey.to_be_bytes()); + self.db + .shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &key); + + shortstatekey +} + +#[implement(Service)] +pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + const BUFSIZE: usize = size_of::(); + + self.db + .shorteventid_eventid + .aqry::(&shorteventid) + .await + .deserialized() + .map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}"))) +} + +#[implement(Service)] +pub async fn multi_get_eventid_from_short(&self, shorteventid: &[u64]) -> Vec>> { + const BUFSIZE: usize = size_of::(); + + let keys: Vec<[u8; BUFSIZE]> = shorteventid + .iter() + .map(|short| short.to_be_bytes()) + .collect(); + + self.db + .shorteventid_eventid + .get_batch_blocking(keys.iter()) + .into_iter() + .map(Deserialized::deserialized) + .collect() +} + +#[implement(Service)] +pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + const BUFSIZE: usize = size_of::(); - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.get_or_create_shortroomid(room_id) + self.db + .shortstatekey_statekey + .aqry::(&shortstatekey) + .await + .deserialized() + .map_err(|e| { + err!(Database( + "Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}" + )) + }) +} + +/// Returns (shortstatehash, already_existed) +#[implement(Service)] +pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) { + if let Ok(shortstatehash) = self + .db + .statehash_shortstatehash + .get(state_hash) + .await + .deserialized() + { + return (shortstatehash, true); } + + let shortstatehash = self.services.globals.next_count().unwrap(); + self.db + .statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes()); + + (shortstatehash, false) +} + +#[implement(Service)] +pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result { + self.db.roomid_shortroomid.qry(room_id).await.deserialized() +} + +#[implement(Service)] +pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 { + self.db + .roomid_shortroomid + .get(room_id) + .await + .deserialized() + .unwrap_or_else(|_| { + let short = self.services.globals.next_count().unwrap(); + self.db + .roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes()); + short + }) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 24d612d87..17fbf0ef0 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,12 @@ use std::{ sync::Arc, }; -use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result}; +use conduit::{ + checked, debug, debug_info, err, + utils::{math::usize_from_f64, IterStream}, + Error, Result, +}; +use futures::{StreamExt, TryFutureExt}; use lru_cache::LruCache; use ruma::{ api::{ @@ -211,12 +216,15 @@ impl Service { .as_ref() { return Ok(if let Some(cached) = cached { - if self.is_accessible_child( - current_room, - &cached.summary.join_rule, - &identifier, - &cached.summary.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &cached.summary.join_rule, + &identifier, + &cached.summary.allowed_room_ids, + ) + .await + { Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) } else { Some(SummaryAccessibility::Inaccessible) @@ -228,7 +236,9 @@ impl Service { Ok( if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { - let summary = self.get_room_summary(current_room, children_pdus, &identifier); + let summary = self + .get_room_summary(current_room, children_pdus, &identifier) + .await; if let Ok(summary) = summary { self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), @@ -322,12 +332,15 @@ impl Service { ); } } - if self.is_accessible_child( - current_room, - &response.room.join_rule, - &Identifier::UserId(user_id), - &response.room.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &response.room.join_rule, + &Identifier::UserId(user_id), + &response.room.allowed_room_ids, + ) + .await + { return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); } @@ -358,7 +371,7 @@ impl Service { } } - fn get_room_summary( + async fn get_room_summary( &self, current_room: &OwnedRoomId, children_state: Vec>, identifier: &Identifier<'_>, ) -> Result { @@ -367,48 +380,43 @@ impl Service { let join_rule = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await + .map_or(JoinRule::Invite, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? - .unwrap_or(JoinRule::Invite); + .unwrap() + }); let allowed_room_ids = self .services .state_accessor .allowed_room_ids(join_rule.clone()); - if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { + if !self + .is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) + .await + { debug!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); } - let join_rule = join_rule.into(); - Ok(SpaceHierarchyParentSummary { canonical_alias: self .services .state_accessor .get_canonical_alias(room_id) - .unwrap_or(None), - name: self - .services - .state_accessor - .get_name(room_id) - .unwrap_or(None), + .await + .ok(), + name: self.services.state_accessor.get_name(room_id).await.ok(), num_joined_members: self .services .state_cache .room_joined_count(room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) + .await + .unwrap_or(0) .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), @@ -416,18 +424,29 @@ impl Service { .services .state_accessor .get_room_topic(room_id) - .unwrap_or(None), - world_readable: self.services.state_accessor.is_world_readable(room_id)?, - guest_can_join: self.services.state_accessor.guest_can_join(room_id)?, + .await + .ok(), + world_readable: self + .services + .state_accessor + .is_world_readable(room_id) + .await, + guest_can_join: self.services.state_accessor.guest_can_join(room_id).await, avatar_url: self .services .state_accessor - .get_avatar(room_id)? + .get_avatar(room_id) + .await .into_option() .unwrap_or_default() .url, - join_rule, - room_type: self.services.state_accessor.get_room_type(room_id)?, + join_rule: join_rule.into(), + room_type: self + .services + .state_accessor + .get_room_type(room_id) + .await + .ok(), children_state, allowed_room_ids, }) @@ -474,21 +493,22 @@ impl Service { results.push(summary_to_chunk(*summary.clone())); } else { children = children - .into_iter() - .rev() - .skip_while(|(room, _)| { - if let Ok(short) = self.services.short.get_shortroomid(room) - { - short.as_ref() != short_room_ids.get(parents.len()) - } else { - false - } - }) - .collect::>() - // skip_while doesn't implement DoubleEndedIterator, which is needed for rev - .into_iter() - .rev() - .collect(); + .iter() + .rev() + .stream() + .skip_while(|(room, _)| { + self.services + .short + .get_shortroomid(room) + .map_ok(|short| Some(&short) != short_room_ids.get(parents.len())) + .unwrap_or_else(|_| false) + }) + .map(Clone::clone) + .collect::)>>() + .await + .into_iter() + .rev() + .collect(); if children.is_empty() { return Err(Error::BadRequest( @@ -531,7 +551,7 @@ impl Service { let mut short_room_ids = vec![]; for room in parents { - short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?); + short_room_ids.push(self.services.short.get_or_create_shortroomid(&room).await); } Some( @@ -554,7 +574,7 @@ impl Service { async fn get_stripped_space_child_events( &self, room_id: &RoomId, ) -> Result>>, Error> { - let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else { + let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else { return Ok(None); }; @@ -562,10 +582,13 @@ impl Service { .services .state_accessor .state_full_ids(current_shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State in space not found: {e}")))?; + let mut children_pdus = Vec::new(); for (key, id) in state { - let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?; + let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?; + if event_type != StateEventType::SpaceChild { continue; } @@ -573,8 +596,9 @@ impl Service { let pdu = self .services .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + .get_pdu(&id) + .await + .map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?; if serde_json::from_str::(pdu.content.get()) .ok() @@ -593,7 +617,7 @@ impl Service { } /// With the given identifier, checks if a room is accessable - fn is_accessible_child( + async fn is_accessible_child( &self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, allowed_room_ids: &Vec, ) -> bool { @@ -607,6 +631,7 @@ impl Service { .services .event_handler .acl_check(server_name, room_id) + .await .is_err() { return false; @@ -617,12 +642,11 @@ impl Service { .services .state_cache .is_joined(user_id, current_room) - .unwrap_or_default() - || self - .services - .state_cache - .is_invited(user_id, current_room) - .unwrap_or_default() + .await || self + .services + .state_cache + .is_invited(user_id, current_room) + .await { return true; } @@ -633,22 +657,12 @@ impl Service { for room in allowed_room_ids { match identifier { Identifier::UserId(user) => { - if self - .services - .state_cache - .is_joined(user, room) - .unwrap_or_default() - { + if self.services.state_cache.is_joined(user, room).await { return true; } }, Identifier::ServerName(server) => { - if self - .services - .state_cache - .server_in_room(server, room) - .unwrap_or_default() - { + if self.services.state_cache.server_in_room(server, room).await { return true; } }, diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 3c110afc6..3072e3c65 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,34 +1,31 @@ -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{EventId, OwnedEventId, RoomId}; +use conduit::{ + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Database, Deserialized, Interfix, Map}; +use ruma::{OwnedEventId, RoomId}; use super::RoomMutexGuard; pub(super) struct Data { shorteventid_shortstatehash: Arc, - roomid_pduleaves: Arc, roomid_shortstatehash: Arc, + pub(super) roomid_pduleaves: Arc, } impl Data { pub(super) fn new(db: &Arc) -> Self { Self { shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), - roomid_pduleaves: db["roomid_pduleaves"].clone(), roomid_shortstatehash: db["roomid_shortstatehash"].clone(), + roomid_pduleaves: db["roomid_pduleaves"].clone(), } } - pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) + pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.roomid_shortstatehash.get(room_id).await.deserialized() } #[inline] @@ -37,53 +34,35 @@ impl Data { room_id: &RoomId, new_shortstatehash: u64, _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - Ok(()) + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes()); } - pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) { self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes()); } - pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - pub(super) fn set_forward_extremities( + pub(super) async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { + let prefix = (room_id, Interfix); + self.roomid_pduleaves + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.roomid_pduleaves.remove(key)) + .await; + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - for event_id in event_ids { let mut key = prefix.clone(); key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + self.roomid_pduleaves.insert(&key, event_id.as_bytes()); } - - Ok(()) } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index cb219bc03..c7f6605c7 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -7,12 +7,14 @@ use std::{ }; use conduit::{ - utils::{calculate_hash, MutexMap, MutexMapGuard}, - warn, Error, PduEvent, Result, + err, + utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard}, + warn, PduEvent, Result, }; use data::Data; +use database::{Ignore, Interfix}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ - api::client::error::ErrorKind, events::{ room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, AnyStrippedStateEvent, StateEventType, TimelineEventType, @@ -81,14 +83,16 @@ impl Service { _statediffremoved: Arc>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { - for event_id in statediffnew.iter().filter_map(|new| { + let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor .parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else { + .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id)) + }); + + pin_mut!(event_ids); + while let Some(event_id) = event_ids.next().await { + let Ok(pdu) = self.services.timeline.get_pdu_json(&event_id).await else { continue; }; @@ -113,15 +117,10 @@ impl Service { continue; }; - self.services.state_cache.update_membership( - room_id, - &user_id, - membership_event, - &pdu.sender, - None, - None, - false, - )?; + self.services + .state_cache + .update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false) + .await?; }, TimelineEventType::SpaceChild => { self.services @@ -135,10 +134,9 @@ impl Service { } } - self.services.state_cache.update_joined_count(room_id)?; + self.services.state_cache.update_joined_count(room_id).await; - self.db - .set_room_state(room_id, shortstatehash, state_lock)?; + self.db.set_room_state(room_id, shortstatehash, state_lock); Ok(()) } @@ -148,12 +146,16 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] - pub fn set_event_state( + pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { - let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.db.get_room_shortstatehash(room_id).await; let state_hash = calculate_hash( &state_ids_compressed @@ -165,13 +167,18 @@ impl Service { let (shortstatehash, already_existed) = self .services .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if !already_existed { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed @@ -198,7 +205,7 @@ impl Service { )?; } - self.db.set_event_state(shorteventid, shortstatehash)?; + self.db.set_event_state(shorteventid, shortstatehash); Ok(shortstatehash) } @@ -208,34 +215,40 @@ impl Service { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu), level = "debug")] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = self .services .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + .get_or_create_shorteventid(&new_pdu.event_id) + .await; - let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await; - if let Some(p) = previous_shortstatehash { - self.db.set_event_state(shorteventid, p)?; + if let Ok(p) = previous_shortstatehash { + self.db.set_event_state(shorteventid, p); } if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - #[inline] - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let shortstatekey = self .services .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await; let new = self .services .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await; let replaces = states_parents .last() @@ -276,49 +289,55 @@ impl Service { } #[tracing::instrument(skip(self, invite_event), level = "debug")] - pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { + pub async fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { let mut state = Vec::new(); // Add recommended events - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCanonicalAlias, "") + .await + { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "") + .await { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { + if let Ok(e) = self + .services + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str()) + .await + { state.push(e.to_stripped_state_event()); } @@ -333,101 +352,108 @@ impl Service { room_id: &RoomId, shortstatehash: u64, mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.db.set_room_state(room_id, shortstatehash, mutex_lock) + ) { + self.db.set_room_state(room_id, shortstatehash, mutex_lock); } /// Returns the room's version. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = self - .services + pub async fn get_room_version(&self, room_id: &RoomId) -> Result { + self.services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: RoomCreateEventContent = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; - - Ok(create_event_content.room_version) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.room_version) + .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) } #[inline] - pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.db.get_room_shortstatehash(room_id) + pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.db.get_room_shortstatehash(room_id).await } - pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - self.db.get_forward_extremities(room_id) + pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + '_ { + let prefix = (room_id, Interfix); + + self.db + .roomid_pduleaves + .keys_prefix(&prefix) + .map_ok(|(_, event_id): (Ignore, &EventId)| event_id) + .ignore_err() } - pub fn set_forward_extremities( + pub async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) { self.db .set_forward_extremities(room_id, event_ids, state_lock) + .await; } /// This fetches auth events from the current state. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_auth_events( + pub async fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { - let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? else { + let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { return Ok(HashMap::new()); }; - let auth_events = - state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object"); + let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)?; - let mut sauthevents = auth_events - .into_iter() + let mut sauthevents: HashMap<_, _> = auth_events + .iter() + .stream() .filter_map(|(event_type, state_key)| { self.services .short - .get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) + .get_shortstatekey(event_type, state_key) + .map_ok(move |s| (s, (event_type, state_key))) + .map(Result::ok) }) - .collect::>(); + .collect() + .await; let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| { + err!(Database( + "Missing shortstatehash info for {room_id:?} at {shortstatehash:?}: {e:?}" + )) + })? .pop() .expect("there is always one layer") .1; - Ok(full_state - .iter() - .filter_map(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - }) - .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) - .filter_map(|(k, event_id)| { - self.services - .timeline - .get_pdu(&event_id) - .ok() - .flatten() - .map(|pdu| (k, pdu)) - }) - .collect()) + let mut ret = HashMap::new(); + for compressed in full_state.iter() { + let Ok((shortstatekey, event_id)) = self + .services + .state_compressor + .parse_compressed_state_event(compressed) + .await + else { + continue; + }; + + let Some((ty, state_key)) = sauthevents.remove(&shortstatekey) else { + continue; + }; + + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { + continue; + }; + + ret.insert((ty.to_owned(), state_key.to_owned()), pdu); + } + + Ok(ret) } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 4c85148db..adc26f000 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, sync::Arc}; -use conduit::{utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{err, PduEvent, Result}; +use database::{Deserialized, Map}; +use futures::TryFutureExt; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{rooms, Dep}; @@ -39,17 +40,22 @@ impl Data { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database("Missing state IDs: {e}")))? .pop() .expect("there is always one layer") .1; + let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { let parsed = self .services .state_compressor - .parse_compressed_state_event(compressed)?; + .parse_compressed_state_event(compressed) + .await?; + result.insert(parsed.0, parsed.1); i = i.wrapping_add(1); @@ -57,6 +63,7 @@ impl Data { tokio::task::yield_now().await; } } + Ok(result) } @@ -67,7 +74,8 @@ impl Data { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await? .pop() .expect("there is always one layer") .1; @@ -78,18 +86,13 @@ impl Data { let (_, eventid) = self .services .state_compressor - .parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); + .parse_compressed_state_event(compressed) + .await?; + + if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await { + if let Some(state_key) = pdu.state_key.as_ref() { + result.insert((pdu.kind.to_string().into(), state_key.clone()), pdu); + } } i = i.wrapping_add(1); @@ -101,61 +104,63 @@ impl Data { Ok(result) } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). #[allow(clippy::unused_self)] - pub(super) fn state_get_id( + pub(super) async fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - let Some(shortstatekey) = self + ) -> Result> { + let shortstatekey = self .services .short - .get_shortstatekey(event_type, state_key)? - else { - return Ok(None); - }; + .get_shortstatekey(event_type, state_key) + .await?; + let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? .pop() .expect("there is always one layer") .1; - Ok(full_state + + let compressed = full_state .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) + .ok_or(err!(Database("No shortstatekey in compressed state")))?; + + self.services + .state_compressor + .parse_compressed_state_event(compressed) + .map_ok(|(_, id)| id) + .map_err(|e| { + err!(Database(error!( + ?event_type, + ?state_key, + ?shortstatekey, + "Failed to parse compressed: {e:?}" + ))) + }) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id)) + ) -> Result> { + self.state_get_id(shortstatehash, event_type, state_key) + .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) + .await } /// Returns the state hash for this pdu. - pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") - }) - }) - .transpose() - }) + .get(event_id) + .and_then(|shorteventid| self.shorteventid_shortstatehash.get(&shorteventid)) + .await + .deserialized() } /// Returns the full room state. @@ -163,34 +168,33 @@ impl Data { pub(super) async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_full(shortstatehash)) + .map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get_id( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key)) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key)) + .await } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 58fa31b3d..4c28483cb 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,8 +6,13 @@ use std::{ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result}; -use data::Data; +use conduit::{ + err, error, + pdu::PduBuilder, + utils::{math::usize_from_f64, ReadyExt}, + Error, PduEvent, Result, +}; +use futures::StreamExt; use lru_cache::LruCache; use ruma::{ events::{ @@ -31,8 +36,10 @@ use ruma::{ EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +use serde::Deserialize; use serde_json::value::to_raw_value; +use self::data::Data; use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; pub struct Service { @@ -99,54 +106,58 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn state_get_id( + pub async fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.state_get_id(shortstatehash, event_type, state_key) + ) -> Result> { + self.db + .state_get_id(shortstatehash, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[inline] - pub fn state_get( + pub async fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.state_get(shortstatehash, event_type, state_key) + ) -> Result> { + self.db + .state_get(shortstatehash, event_type, state_key) + .await } /// Get membership for given user in state - fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { - self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(MembershipState::Leave), |s| { + async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState { + self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_or(MembershipState::Leave, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomMemberEventContent| c.membership) .map_err(|_| Error::bad_database("Invalid room membership event in database.")) + .unwrap() }) } /// The user was a joined member at this state (potentially in the past) #[inline] - fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join) - // Return sensible default, i.e. - // false + async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id).await == MembershipState::Join } /// The user was an invited or joined room member at this state (potentially /// in the past) #[inline] - fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) - // Return sensible default, i.e. false + async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { + let s = self.user_membership(shortstatehash, user_id).await; + s == MembershipState::Join || s == MembershipState::Invite } /// Whether a server is allowed to see an event through federation, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { + pub async fn server_can_see_event( + &self, origin: &ServerName, room_id: &RoomId, event_id: &EventId, + ) -> Result { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { return Ok(true); }; @@ -160,8 +171,9 @@ impl Service { } let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(HistoryVisibility::Shared), |s| { + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { @@ -171,25 +183,28 @@ impl Service { ); Error::bad_database("Invalid history visibility event in database.") }) - }) - .unwrap_or(HistoryVisibility::Shared); + .unwrap() + }); - let mut current_server_members = self + let current_server_members = self .services .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member.server_name() == origin); + .ready_filter(|member| member.server_name() == origin); let visibility = match history_visibility { HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, HistoryVisibility::Invited => { // Allow if any member on requesting server was AT LEAST invited, else deny - current_server_members.any(|member| self.user_was_invited(shortstatehash, &member)) + current_server_members + .any(|member| self.user_was_invited(shortstatehash, member)) + .await }, HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - current_server_members.any(|member| self.user_was_joined(shortstatehash, &member)) + current_server_members + .any(|member| self.user_was_joined(shortstatehash, member)) + .await }, _ => { error!("Unknown history visibility {history_visibility}"); @@ -208,9 +223,9 @@ impl Service { /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id, event_id))] - pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { - return Ok(true); + pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + return true; }; if let Some(visibility) = self @@ -219,14 +234,15 @@ impl Service { .unwrap() .get_mut(&(user_id.to_owned(), shortstatehash)) { - return Ok(*visibility); + return *visibility; } - let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(HistoryVisibility::Shared), |s| { + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { @@ -236,19 +252,19 @@ impl Service { ); Error::bad_database("Invalid history visibility event in database.") }) - }) - .unwrap_or(HistoryVisibility::Shared); + .unwrap() + }); let visibility = match history_visibility { HistoryVisibility::WorldReadable => true, HistoryVisibility::Shared => currently_member, HistoryVisibility::Invited => { // Allow if any member on requesting server was AT LEAST invited, else deny - self.user_was_invited(shortstatehash, user_id) + self.user_was_invited(shortstatehash, user_id).await }, HistoryVisibility::Joined => { // Allow if any member on requested server was joined, else deny - self.user_was_joined(shortstatehash, user_id) + self.user_was_joined(shortstatehash, user_id).await }, _ => { error!("Unknown history visibility {history_visibility}"); @@ -261,17 +277,18 @@ impl Service { .unwrap() .insert((user_id.to_owned(), shortstatehash), visibility); - Ok(visibility) + visibility } /// Whether a user is allowed to see an event, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id))] - pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; + pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; let history_visibility = self - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? + .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "") + .await .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) @@ -285,11 +302,13 @@ impl Service { }) .unwrap_or(HistoryVisibility::Shared); - Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) + currently_member || history_visibility == HistoryVisibility::WorldReadable } /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.db.pdu_shortstatehash(event_id) } + pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + self.db.pdu_shortstatehash(event_id).await + } /// Returns the full room state. #[tracing::instrument(skip(self), level = "debug")] @@ -300,47 +319,61 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get_id( + pub async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.room_state_get_id(room_id, event_type, state_key) + ) -> Result> { + self.db + .room_state_get_id(room_id, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get( + pub async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.db.room_state_get(room_id, event_type, state_key) + ) -> Result> { + self.db.room_state_get(room_id, event_type, state_key).await } - pub fn get_name(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomName, "")? - .map_or(Ok(None), |s| { - Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) - }) + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub async fn room_state_get_content( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result + where + T: for<'de> Deserialize<'de> + Send, + { + use serde_json::from_str; + + self.room_state_get(room_id, event_type, state_key) + .await + .and_then(|event| from_str::(event.content.get()).map_err(Into::into)) } - pub fn get_avatar(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map_or(Ok(ruma::JsOption::Undefined), |s| { + pub async fn get_name(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomName, "") + .await + .map(|c: RoomNameEventContent| c.name) + } + + pub async fn get_avatar(&self, room_id: &RoomId) -> ruma::JsOption { + self.room_state_get(room_id, &StateEventType::RoomAvatar, "") + .await + .map_or(ruma::JsOption::Undefined, |s| { serde_json::from_str(s.content.get()) .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + .unwrap() }) } - pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room member event in database.")) - }) + pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await } - pub fn user_can_invite( + pub async fn user_can_invite( &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, - ) -> Result { + ) -> bool { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) .expect("Event content always serializes"); @@ -353,122 +386,101 @@ impl Service { timestamp: None, }; - Ok(self - .services + self.services .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) - .is_ok()) + .await + .is_ok() } /// Checks if guests are able to view room content without joining - pub fn is_world_readable(&self, room_id: &RoomId) -> Result { - self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|e| { - error!( - "Invalid room history visibility event in database for room {room_id}, assuming not world \ - readable: {e} " - ); - Error::bad_database("Invalid room history visibility event in database.") - }) - }) + pub async fn is_world_readable(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") + .await + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) + .unwrap_or(false) } /// Checks if guests are able to join a given room - pub fn guest_can_join(&self, room_id: &RoomId) -> Result { - self.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) - .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) - }) + pub async fn guest_can_join(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomGuestAccess, "") + .await + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .unwrap_or(false) } /// Gets the primary alias from canonical alias event - pub fn get_canonical_alias(&self, room_id: &RoomId) -> Result, Error> { - self.room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomCanonicalAliasEventContent| c.alias) - .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + pub async fn get_canonical_alias(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomCanonicalAlias, "") + .await + .and_then(|c: RoomCanonicalAliasEventContent| { + c.alias + .ok_or_else(|| err!(Request(NotFound("No alias found in event content.")))) }) } /// Gets the room topic - pub fn get_room_topic(&self, room_id: &RoomId) -> Result, Error> { - self.room_state_get(room_id, &StateEventType::RoomTopic, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomTopicEventContent| Some(c.topic)) - .map_err(|e| { - error!("Invalid room topic event in database for room {room_id}: {e}"); - Error::bad_database("Invalid room topic event in database.") - }) - }) + pub async fn get_room_topic(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomTopic, "") + .await + .map(|c: RoomTopicEventContent| c.topic) } /// Checks if a given user can redact a given event /// /// If federation is true, it allows redaction events from any user of the /// same server as the original event sender - pub fn user_can_redact( + pub async fn user_can_redact( &self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool, ) -> Result { - self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map_or_else( - || { - // Falling back on m.room.create to judge power level - if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? { - Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - pdu.sender == sender - } else { - false - }) + if let Ok(event) = self + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await + { + let Ok(event) = serde_json::from_str(event.content.get()) + .map(|content: RoomPowerLevelsEventContent| content.into()) + .map(|event: RoomPowerLevels| event) + else { + return Ok(false); + }; + + Ok(event.user_can_redact_event_of_other(sender) + || event.user_can_redact_own_event(sender) + && if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + if federation { + pdu.sender.server_name() == sender.server_name() + } else { + pdu.sender == sender + } + } else { + false + }) + } else { + // Falling back on m.room.create to judge power level + if let Ok(pdu) = self + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + { + Ok(pdu.sender == sender + || if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await { + pdu.sender == sender } else { - Err(Error::bad_database( - "No m.room.power_levels or m.room.create events in database for room", - )) - } - }, - |event| { - serde_json::from_str(event.content.get()) - .map(|content: RoomPowerLevelsEventContent| content.into()) - .map(|event: RoomPowerLevels| { - event.user_can_redact_event_of_other(sender) - || event.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - if federation { - pdu.sender.server_name() == sender.server_name() - } else { - pdu.sender == sender - } - } else { - false - } - }) - .map_err(|_| Error::bad_database("Invalid m.room.power_levels event in database")) - }, - ) + false + }) + } else { + Err(Error::bad_database( + "No m.room.power_levels or m.room.create events in database for room", + )) + } + } } /// Returns the join rule (`SpaceRoomJoinRule`) for a given room - pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec), Error> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| { - (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) - }) - .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? - .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) + pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec)> { + self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))) + .or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![]))) } /// Returns an empty vec if not a restricted room @@ -487,25 +499,21 @@ impl Service { room_ids } - pub fn get_room_type(&self, room_id: &RoomId) -> Result> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .map(|s| { - serde_json::from_str::(s.content.get()) - .map_err(|e| err!(Database(error!("Invalid room create event in database: {e}")))) + pub async fn get_room_type(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .and_then(|content: RoomCreateEventContent| { + content + .room_type + .ok_or_else(|| err!(Request(NotFound("No type found in event content")))) }) - .transpose()? - .and_then(|e| e.room_type)) } /// Gets the room's encryption algorithm if `m.room.encryption` state event /// is found - pub fn get_room_encryption(&self, room_id: &RoomId) -> Result> { - self.room_state_get(room_id, &StateEventType::RoomEncryption, "")? - .map_or(Ok(None), |s| { - serde_json::from_str::(s.content.get()) - .map(|content| Some(content.algorithm)) - .map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}")))) - }) + pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result { + self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "") + .await + .map(|content: RoomEncryptionEventContent| content.algorithm) } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 19c73ea1c..f3ccaf102 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,43 +1,42 @@ use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{Arc, RwLock}, }; -use conduit::{utils, Error, Result}; -use database::Map; -use itertools::Itertools; +use conduit::{utils, utils::stream::TryIgnore, Error, Result}; +use database::{Deserialized, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, RoomId, UserId, }; -use crate::{appservice::RegistrationInfo, globals, users, Dep}; +use crate::{globals, Dep}; -type StrippedStateEventIter<'a> = Box>)>> + 'a>; -type AnySyncStateEventIter<'a> = Box>)>> + 'a>; type AppServiceInRoomCache = RwLock>>; +type StrippedStateEventItem = (OwnedRoomId, Vec>); +type SyncStateEventItem = (OwnedRoomId, Vec>); pub(super) struct Data { pub(super) appservice_in_room_cache: AppServiceInRoomCache, - roomid_invitedcount: Arc, - roomid_inviteviaservers: Arc, - roomid_joinedcount: Arc, - roomserverids: Arc, - roomuserid_invitecount: Arc, - roomuserid_joined: Arc, - roomuserid_leftcount: Arc, - roomuseroncejoinedids: Arc, - serverroomids: Arc, - userroomid_invitestate: Arc, - userroomid_joined: Arc, - userroomid_leftstate: Arc, + pub(super) roomid_invitedcount: Arc, + pub(super) roomid_inviteviaservers: Arc, + pub(super) roomid_joinedcount: Arc, + pub(super) roomserverids: Arc, + pub(super) roomuserid_invitecount: Arc, + pub(super) roomuserid_joined: Arc, + pub(super) roomuserid_leftcount: Arc, + pub(super) roomuseroncejoinedids: Arc, + pub(super) serverroomids: Arc, + pub(super) userroomid_invitestate: Arc, + pub(super) userroomid_joined: Arc, + pub(super) userroomid_leftstate: Arc, services: Services, } struct Services { globals: Dep, - users: Dep, } impl Data { @@ -59,19 +58,18 @@ impl Data { userroomid_leftstate: db["userroomid_leftstate"].clone(), services: Services { globals: args.depend::("globals"), - users: args.depend::("users"), }, } } - pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]) + self.roomuseroncejoinedids.insert(&userroom_id, &[]); } - pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -82,64 +80,17 @@ impl Data { userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; + self.userroomid_joined.insert(&userroom_id, &[]); + self.roomuserid_joined.insert(&roomuser_id, &[]); + self.userroomid_invitestate.remove(&userroom_id); + self.roomuserid_invitecount.remove(&roomuser_id); + self.userroomid_leftstate.remove(&userroom_id); + self.roomuserid_leftcount.remove(&roomuser_id); - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) + self.roomid_inviteviaservers.remove(&roomid); } - pub(super) fn mark_as_invited( - &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, - invite_via: Option>, - ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - if let Some(servers) = invite_via { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - #[allow(clippy::redundant_clone)] // this is a necessary clone? - prev_servers.append(servers.clone().as_mut()); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - } - - Ok(()) - } - - pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -153,115 +104,20 @@ impl Data { self.userroomid_leftstate.insert( &userroom_id, &serde_json::to_vec(&Vec::>::new()).unwrap(), - )?; // TODO + ); // TODO self.roomuserid_leftcount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) - } - - pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - - for joined in self.room_members(room_id).filter_map(Result::ok) { - joined_servers.insert(joined.server_name().to_owned()); - joinedcount = joinedcount.saturating_add(1); - } - - for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { - invitedcount = invitedcount.saturating_add(1); - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.userroomid_joined.remove(&userroom_id); + self.roomuserid_joined.remove(&roomuser_id); + self.userroomid_invitestate.remove(&userroom_id); + self.roomuserid_invitecount.remove(&roomuser_id); - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.registration.id)) - .copied(); - - if let Some(b) = maybe { - Ok(b) - } else { - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - self.services.globals.server_name(), - ) - .ok(); - - let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self - .room_members(room_id) - .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); - - Ok(in_room) - } + self.roomid_inviteviaservers.remove(&roomid); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -270,397 +126,63 @@ impl Data { roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_servers<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - })) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - let mut key = server.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we - /// know). - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_rooms<'a>( - &'a self, server: &ServerName, - ) -> Box> + 'a> { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - })) - } - - /// Returns an iterator of all joined members of a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + Send + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - })) - } - - /// Returns an iterator of all our local users in the room, even if they're - /// deactivated/guests - pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { - Box::new( - self.room_members(room_id) - .filter_map(Result::ok) - .filter(|user| self.services.globals.user_is_local(user)), - ) - } - - /// Returns an iterator of all our local joined users in a room who are - /// active (not deactivated, not guest) - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn active_local_users_in_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box + 'a> { - Box::new( - self.local_users_in_room(room_id) - .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)), - ) - } - - /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_useroncejoined<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) - } - - /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members_invited<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, - )) - }) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) - .transpose() - } - - /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box> + '_> { - Box::new( - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }), - ) + self.userroomid_leftstate.remove(&userroom_id); + self.roomuserid_leftcount.remove(&roomuser_id); } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn invite_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - + #[inline] + pub(super) fn rooms_invited<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); self.userroomid_invitestate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok(state) + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); + let room_id = utils::string_from_bytes(room_id).unwrap(); + let room_id = RoomId::parse(room_id).unwrap(); + let state = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate.")) + .unwrap(); + + (room_id, state) }) - .transpose() } #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn left_state( + pub(super) async fn invite_state( &self, user_id: &UserId, room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() - } - - /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok((room_id, state)) - }), - ) + ) -> Result>> { + let key = (user_id, room_id); + self.userroomid_invitestate.qry(&key).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn servers_invite_via<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let key = room_id.as_bytes().to_vec(); - - Box::new( - self.roomid_inviteviaservers - .scan_prefix(key) - .map(|(_, servers)| { - ServerName::parse( - utils::string_from_bytes( - servers - .rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid.")) - }), - ) + pub(super) async fn left_state( + &self, user_id: &UserId, room_id: &RoomId, + ) -> Result>> { + let key = (user_id, room_id); + self.userroomid_leftstate.qry(&key).await.deserialized() } - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - prev_servers.extend(servers.to_owned()); - prev_servers.sort_unstable(); - prev_servers.dedup(); - - let servers = prev_servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - - Ok(()) + /// Returns an iterator over all rooms a user left. + #[inline] + pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.userroomid_leftstate + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap(); + let room_id = utils::string_from_bytes(room_id).unwrap(); + let room_id = RoomId::parse(room_id).unwrap(); + let state = serde_json::from_slice(val) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate.")) + .unwrap(); + + (room_id, state) + }) } } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 71899ceb9..dbe385619 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,9 +1,15 @@ mod data; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use conduit::{err, error, warn, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt, StreamTools}, + warn, Result, +}; use data::Data; +use database::{Deserialized, Ignore, Interfix}; +use futures::{Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -18,7 +24,7 @@ use ruma::{ }, int, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep}; @@ -55,7 +61,7 @@ impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] - pub fn update_membership( + pub async fn update_membership( &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, last_state: Option>>, invite_via: Option>, update_joined_count: bool, @@ -68,7 +74,7 @@ impl Service { // update #[allow(clippy::collapsible_if)] if !self.services.globals.user_is_local(user_id) { - if !self.services.users.exists(user_id)? { + if !self.services.users.exists(user_id).await { self.services.users.create(user_id, None)?; } @@ -100,17 +106,17 @@ impl Service { match &membership { MembershipState::Join => { // Check if the user never joined this room - if !self.once_joined(user_id, room_id)? { + if !self.once_joined(user_id, room_id).await { // Add the user ID to the join list then - self.db.mark_as_once_joined(user_id, room_id)?; + self.db.mark_as_once_joined(user_id, room_id); // Check if the room has a predecessor - if let Some(predecessor) = self + if let Ok(Some(predecessor)) = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.predecessor) { // Copy user settings from predecessor to the current room: // - Push rules @@ -138,32 +144,33 @@ impl Service { // .ok(); // Copy old tags to new room - if let Some(tag_event) = self + if let Ok(tag_event) = self .services .account_data - .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| { + .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag) + .await + .and_then(|event| { serde_json::from_str(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { self.services .account_data - .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) + .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event) + .await .ok(); }; // Copy direct chat flag - if let Some(direct_event) = self + if let Ok(mut direct_event) = self .services .account_data - .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? - .map(|event| { + .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into()) + .await + .and_then(|event| { serde_json::from_str::(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { - let mut direct_event = direct_event?; let mut room_ids_updated = false; - for room_ids in direct_event.content.0.values_mut() { if room_ids.iter().any(|r| r == &predecessor.room_id) { room_ids.push(room_id.to_owned()); @@ -172,18 +179,21 @@ impl Service { } if room_ids_updated { - self.services.account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event).expect("to json always works"), - )?; + self.services + .account_data + .update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event).expect("to json always works"), + ) + .await?; } }; } } - self.db.mark_as_joined(user_id, room_id)?; + self.db.mark_as_joined(user_id, room_id); }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver @@ -196,12 +206,12 @@ impl Service { GlobalAccountDataEventType::IgnoredUserList .to_string() .into(), - )? - .map(|event| { + ) + .await + .and_then(|event| { serde_json::from_str::(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) - .transpose()? .map_or(false, |ignored| { ignored .content @@ -214,194 +224,284 @@ impl Service { return Ok(()); } - self.db - .mark_as_invited(user_id, room_id, last_state, invite_via)?; + self.mark_as_invited(user_id, room_id, last_state, invite_via) + .await; }, MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id)?; + self.db.mark_as_left(user_id, room_id); }, _ => {}, } if update_joined_count { - self.update_joined_count(room_id)?; + self.update_joined_count(room_id).await; } Ok(()) } - #[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - self.db.appservice_in_room(room_id, appservice) + pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool { + let maybe = self + .db + .appservice_in_room_cache + .read() + .unwrap() + .get(room_id) + .and_then(|map| map.get(&appservice.registration.id)) + .copied(); + + if let Some(b) = maybe { + b + } else { + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + self.services.globals.server_name(), + ) + .ok(); + + let in_room = if let Some(id) = &bridge_user_id { + self.is_joined(id, room_id).await + } else { + false + }; + + let in_room = in_room + || self + .room_members(room_id) + .ready_any(|userid| appservice.users.is_match(userid.as_str())) + .await; + + self.db + .appservice_in_room_cache + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default() + .insert(appservice.registration.id.clone(), in_room); + + in_room + } } /// Direct DB function to directly mark a user as left. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.mark_as_left(user_id, room_id) - } + pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_left(user_id, room_id); } /// Direct DB function to directly mark a user as joined. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.mark_as_joined(user_id, room_id) - } + pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_joined(user_id, room_id); } /// Makes a user forget a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { self.db.forget(room_id, user_id); } /// Returns an iterator of all servers participating in this room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_servers(room_id) + pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomserverids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, server): (Ignore, &ServerName)| server) } #[tracing::instrument(skip(self), level = "debug")] - pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - self.db.server_in_room(server, room_id) + pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool { + let key = (server, room_id); + self.db.serverroomids.qry(&key).await.is_ok() } /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self), level = "debug")] - pub fn server_rooms(&self, server: &ServerName) -> impl Iterator> + '_ { - self.db.server_rooms(server) + pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream + Send + 'a { + let prefix = (server, Interfix); + self.db + .serverroomids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns true if server can see user by sharing at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result { - Ok(self - .server_rooms(server) - .filter_map(Result::ok) - .any(|room_id: OwnedRoomId| self.is_joined(user_id, &room_id).unwrap_or(false))) + pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool { + self.server_rooms(server) + .any(|room_id| self.is_joined(user_id, room_id)) + .await } /// Returns true if user_a and user_b share at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result { + pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool { // Minimize number of point-queries by iterating user with least nr rooms - let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { + let (a, b) = if self.rooms_joined(user_a).count().await < self.rooms_joined(user_b).count().await { (user_a, user_b) } else { (user_b, user_a) }; - Ok(self - .rooms_joined(a) - .filter_map(Result::ok) - .any(|room_id| self.is_joined(b, &room_id).unwrap_or(false))) + self.rooms_joined(a) + .any(|room_id| self.is_joined(b, room_id)) + .await } - /// Returns an iterator over all joined members of a room. + /// Returns an iterator of all joined members of a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> + Send + '_ { - self.db.room_members(room_id) + pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_joined + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } + pub async fn room_joined_count(&self, room_id: &RoomId) -> Result { + self.db.roomid_joinedcount.qry(room_id).await.deserialized() + } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests - pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { - self.db.local_users_in_room(room_id) + pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + self.room_members(room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) - pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { - self.db.active_local_users_in_room(room_id) + pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + self.local_users_in_room(room_id) + .filter(|user| self.services.users.is_active(user)) } /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } + pub async fn room_invited_count(&self, room_id: &RoomId) -> Result { + self.db + .roomid_invitedcount + .qry(room_id) + .await + .deserialized() + } /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_useroncejoined(room_id) + pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuseroncejoinedids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns an iterator over all invited members of a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.room_members_invited(room_id) + pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_invitecount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_invite_count(room_id, user_id) + pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_invitecount + .qry(&key) + .await + .deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_left_count(room_id, user_id) + pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db.roomuserid_leftcount.qry(&key).await.deserialized() } /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator> + '_ { - self.db.rooms_joined(user_id) + pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + self.db + .userroomid_joined + .keys_prefix_raw(user_id) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_invited( - &self, user_id: &UserId, - ) -> impl Iterator>)>> + '_ { + pub fn rooms_invited<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream>)> + Send + 'a { self.db.rooms_invited(user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - self.db.invite_state(user_id, room_id) + pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + self.db.invite_state(user_id, room_id).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - self.db.left_state(user_id, room_id) + pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + self.db.left_state(user_id, room_id).await } /// Returns an iterator over all rooms a user left. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_left( - &self, user_id: &UserId, - ) -> impl Iterator>)>> + '_ { + pub fn rooms_left<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream>)> + Send + 'a { self.db.rooms_left(user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.once_joined(user_id, room_id) + pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_joined(user_id, room_id) } + pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_joined.qry(&key).await.is_ok() + } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_invited(user_id, room_id) + pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_invitestate.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_leftstate.qry(&key).await.is_ok() + } #[tracing::instrument(skip(self), level = "debug")] - pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator> + '_ { - self.db.servers_invite_via(room_id) + pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Ignore, Vec<&'a ServerName>); + + self.db + .roomid_inviteviaservers + .stream_prefix_raw(room_id) + .ignore_err() + .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) } /// Gets up to three servers that are likely to be in the room in the @@ -409,44 +509,32 @@ impl Service { /// /// See #[tracing::instrument(skip(self))] - pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { + pub async fn servers_route_via(&self, room_id: &RoomId) -> Result> { let most_powerful_user_server = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map(|pdu| { - serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| { - conent - .users - .iter() - .max_by_key(|(_, power)| *power) - .and_then(|x| { - if x.1 >= &int!(50) { - Some(x) - } else { - None - } - }) - .map(|(user, _power)| user.server_name().to_owned()) - }) + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") + .await + .map(|content: RoomPowerLevelsEventContent| { + content + .users + .iter() + .max_by_key(|(_, power)| *power) + .and_then(|x| (x.1 >= &int!(50)).then_some(x)) + .map(|(user, _power)| user.server_name().to_owned()) }) - .transpose() - .map_err(|e| { - error!("Invalid power levels event content in database: {e}"); - Error::bad_database("Invalid power levels event content in database") - })? - .flatten(); + .map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?; let mut servers: Vec = self .room_members(room_id) - .filter_map(Result::ok) .counts_by(|user| user.server_name().to_owned()) - .iter() + .await + .into_iter() .sorted_by_key(|(_, users)| *users) - .map(|(server, _)| server.to_owned()) + .map(|(server, _)| server) .rev() .take(3) - .collect_vec(); + .collect(); if let Some(server) = most_powerful_user_server { servers.insert(0, server); @@ -468,4 +556,123 @@ impl Service { .expect("locked") .clear(); } + + pub async fn update_joined_count(&self, room_id: &RoomId) { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + self.room_members(room_id) + .ready_for_each(|joined| { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount = joinedcount.saturating_add(1); + }) + .await; + + invitedcount = invitedcount.saturating_add( + self.room_members_invited(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); + + self.db + .roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes()); + + self.db + .roomid_invitedcount + .insert(room_id.as_bytes(), &invitedcount.to_be_bytes()); + + self.room_servers(room_id) + .ready_for_each(|old_joined_server| { + if !joined_servers.remove(old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.roomserverids.remove(&roomserver_id); + self.db.serverroomids.remove(&serverroom_id); + } + }) + .await; + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.roomserverids.insert(&roomserver_id, &[]); + self.db.serverroomids.insert(&serverroom_id, &[]); + } + + self.db + .appservice_in_room_cache + .write() + .unwrap() + .remove(room_id); + } + + pub async fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, + ) { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.db.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), + ); + self.db + .roomuserid_invitecount + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + if let Some(servers) = invite_via.as_deref() { + self.add_servers_invite_via(room_id, servers).await; + } + } + + #[tracing::instrument(skip(self, servers), level = "debug")] + pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) { + let mut prev_servers: Vec<_> = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .collect() + .await; + + prev_servers.extend(servers.to_owned()); + prev_servers.sort_unstable(); + prev_servers.dedup(); + + let servers = prev_servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); + } } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 337730019..cb0204705 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, Result}; +use conduit::{err, expected, utils, Result}; use database::{Database, Map}; use super::CompressedStateEvent; @@ -22,11 +22,15 @@ impl Data { } } - pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result { + pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result { + const BUFSIZE: usize = size_of::(); + let value = self .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + .aqry::(&shortstatehash) + .await + .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; + let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); let parent = if parent != 0 { Some(parent) @@ -40,10 +44,10 @@ impl Data { let stride = size_of::(); let mut i = stride; - while let Some(v) = value.get(i..checked!(i + 2 * stride)?) { + while let Some(v) = value.get(i..expected!(i + 2 * stride)) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) { add_mode = false; - i = checked!(i + stride)?; + i = expected!(i + stride); continue; } if add_mode { @@ -51,7 +55,7 @@ impl Data { } else { removed.insert(v.try_into().expect("we checked the size above")); } - i = checked!(i + 2 * stride)?; + i = expected!(i + 2 * stride); } Ok(StateDiff { @@ -61,7 +65,7 @@ impl Data { }) } - pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) -> Result<()> { + pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); for new in diff.added.iter() { value.extend_from_slice(&new[..]); @@ -75,6 +79,6 @@ impl Data { } self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value) + .insert(&shortstatehash.to_be_bytes(), &value); } } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 2550774e1..cd3f2f738 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -27,14 +27,12 @@ type StateInfoLruCache = Mutex< >, >; -type ShortStateInfoResult = Result< - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, ->; +type ShortStateInfoResult = Vec<( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed +)>; type ParentStatesVec = Vec<( u64, // sstatehash @@ -43,7 +41,7 @@ type ParentStatesVec = Vec<( Arc>, // removed )>; -type HashSetCompressStateEvent = Result<(u64, Arc>, Arc>)>; +type HashSetCompressStateEvent = (u64, Arc>, Arc>); pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Service { @@ -86,12 +84,11 @@ impl crate::Service for Service { impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument(skip(self), level = "debug")] - pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { + pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { if let Some(r) = self .stateinfo_cache .lock() - .unwrap() + .expect("locked") .get_mut(&shortstatehash) { return Ok(r.clone()); @@ -101,11 +98,11 @@ impl Service { parent, added, removed, - } = self.db.get_statediff(shortstatehash)?; + } = self.db.get_statediff(shortstatehash).await?; if let Some(parent) = parent { - let mut response = self.load_shortstatehash_info(parent)?; - let mut state = (*response.last().unwrap().1).clone(); + let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; + let mut state = (*response.last().expect("at least one response").1).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { @@ -116,7 +113,7 @@ impl Service { self.stateinfo_cache .lock() - .unwrap() + .expect("locked") .insert(shortstatehash, response.clone()); Ok(response) @@ -124,33 +121,42 @@ impl Service { let response = vec![(shortstatehash, added.clone(), added, removed)]; self.stateinfo_cache .lock() - .unwrap() + .expect("locked") .insert(shortstatehash, response.clone()); + Ok(response) } } - pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { + pub async fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> CompressedStateEvent { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( &self .services .short - .get_or_create_shorteventid(event_id)? + .get_or_create_shorteventid(event_id) + .await .to_be_bytes(), ); - Ok(v.try_into().expect("we checked the size above")) + + v.try_into().expect("we checked the size above") } /// Returns shortstatekey, event id #[inline] - pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { - Ok(( - utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), - self.services.short.get_eventid_from_short( - utils::u64_from_bytes(&compressed_event[size_of::()..]).expect("bytes have right length"), - )?, - )) + pub async fn parse_compressed_state_event( + &self, compressed_event: &CompressedStateEvent, + ) -> Result<(u64, Arc)> { + use utils::u64_from_u8; + + let shortstatekey = u64_from_u8(&compressed_event[0..size_of::()]); + let event_id = self + .services + .short + .get_eventid_from_short(u64_from_u8(&compressed_event[size_of::()..])) + .await?; + + Ok((shortstatekey, event_id)) } /// Creates a new shortstatehash that often is just a diff to an already @@ -227,7 +233,7 @@ impl Service { added: statediffnew, removed: statediffremoved, }, - )?; + ); return Ok(()); }; @@ -280,7 +286,7 @@ impl Service { added: statediffnew, removed: statediffremoved, }, - )?; + ); } Ok(()) @@ -288,10 +294,15 @@ impl Service { /// Returns the new shortstatehash, and the state diff from the previous /// room state - pub fn save_state( + pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, - ) -> HashSetCompressStateEvent { - let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?; + ) -> Result { + let previous_shortstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .ok(); let state_hash = utils::calculate_hash( &new_state_ids_compressed @@ -303,14 +314,18 @@ impl Service { let (new_shortstatehash, already_existed) = self .services .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if Some(new_shortstatehash) == previous_shortstatehash { return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); } - let states_parents = - previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + let states_parents = if let Some(p) = previous_shortstatehash { + self.load_shortstatehash_info(p).await.unwrap_or_default() + } else { + ShortStateInfoResult::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index fb279a007..f50b812ca 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,13 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{ + checked, + result::LogErr, + utils, + utils::{stream::TryIgnore, ReadyExt}, + PduEvent, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use crate::{rooms, Dep}; -type PduEventIterResult<'a> = Result> + 'a>>; - pub(super) struct Data { threadid_userids: Arc, services: Services, @@ -30,38 +35,37 @@ impl Data { } } - pub(super) fn threads_until<'a>( + pub(super) async fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> PduEventIterResult<'a> { + ) -> Result + Send + 'a> { let prefix = self .services .short - .get_shortroomid(room_id)? - .expect("room exists") + .get_shortroomid(room_id) + .await? .to_be_bytes() .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); - Ok(Box::new( - self.threadid_userids - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(size_of::())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((count, pdu)) - }), - )) + let stream = self + .threadid_userids + .rev_raw_keys_from(¤t) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|pduid| (utils::u64_from_u8(&pduid[(size_of::())..]), pduid)) + .filter_map(move |(count, pduid)| async move { + let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?; + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((count, pdu)) + }); + + Ok(stream) } pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { @@ -71,28 +75,12 @@ impl Data { .collect::>() .join(&[0xFF][..]); - self.threadid_userids.insert(root_id, &users)?; + self.threadid_userids.insert(root_id, &users); Ok(()) } - pub(super) fn get_participants(&self, root_id: &[u8]) -> Result>> { - if let Some(users) = self.threadid_userids.get(root_id)? { - Ok(Some( - users - .split(|b| *b == 0xFF) - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) - }) - .filter_map(Result::ok) - .collect(), - )) - } else { - Ok(None) - } + pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result> { + self.threadid_userids.qry(root_id).await.deserialized() } } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index ae51cd0f9..2eafe5d52 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,12 +2,12 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, PduEvent, Result}; +use conduit::{err, PduEvent, Result}; use data::Data; +use futures::Stream; use ruma::{ - api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, - events::relation::BundledThread, - uint, CanonicalJsonValue, EventId, RoomId, UserId, + api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue, + EventId, RoomId, UserId, }; use serde_json::json; @@ -36,30 +36,35 @@ impl crate::Service for Service { } impl Service { - pub fn threads_until<'a>( + pub async fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result> + 'a> { - self.db.threads_until(user_id, room_id, until, include) + ) -> Result + Send + 'a> { + self.db + .threads_until(user_id, room_id, until, include) + .await } - pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { + pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = self .services .timeline - .get_pdu_id(root_event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; + .get_pdu_id(root_event_id) + .await + .map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?; let root_pdu = self .services .timeline - .get_pdu_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?; let mut root_pdu_json = self .services .timeline - .get_pdu_json_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_json_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?; if let CanonicalJsonValue::Object(unsigned) = root_pdu_json .entry("unsigned".to_owned()) @@ -103,11 +108,12 @@ impl Service { self.services .timeline - .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?; + .replace_pdu(&root_id, &root_pdu_json, &root_pdu) + .await?; } let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(&root_id)? { + if let Ok(userids) = self.db.get_participants(&root_id).await { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 2f0c8f258..cb85cf19c 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,12 +1,20 @@ use std::{ collections::{hash_map, HashMap}, mem::size_of, - sync::{Arc, Mutex}, + sync::Arc, }; -use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result}; -use database::{Database, Map}; -use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::{ + err, expected, + result::{LogErr, NotFound}, + utils, + utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, + Err, PduCount, PduEvent, Result, +}; +use database::{Database, Deserialized, KeyVal, Map}; +use futures::{FutureExt, Stream, StreamExt}; +use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use tokio::sync::Mutex; use crate::{rooms, Dep}; @@ -25,8 +33,7 @@ struct Services { short: Dep, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box + 'a>; +pub type PdusIterItem = (PduCount, PduEvent); type LastTimelineCountCache = Mutex>; impl Data { @@ -46,23 +53,20 @@ impl Data { } } - pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub(super) async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache .lock() - .expect("locked") + .await .entry(room_id.to_owned()) { hash_map::Entry::Vacant(v) => { if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max())? - .find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) { + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .next() + .await + { Ok(*v.insert(last_count.0)) } else { Ok(PduCount::Normal(0)) @@ -73,232 +77,212 @@ impl Data { } /// Returns the `count` of this pdu's id. - pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result> { + pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.eventid_pduid - .get(event_id.as_bytes())? + .get(event_id) + .await .map(|pdu_id| pdu_count(&pdu_id)) - .transpose() } /// Returns the json of a pdu. - pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.get_non_outlier_pdu_json(event_id)?.map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - ) + pub(super) async fn get_pdu_json(&self, event_id: &EventId) -> Result { + if let Ok(pdu) = self.get_non_outlier_pdu_json(event_id).await { + return Ok(pdu); + } + + self.eventid_outlierpdu.get(event_id).await.deserialized() } /// Returns the json of a pdu. - pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.get(&pduid).await.deserialized() } /// Returns the pdu's id. #[inline] - pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) + pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.eventid_pduid.get(event_id).await } /// Returns the pdu directly from `eventid_pduid` only. - pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.get(&pduid).await.deserialized() + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.get(&pduid).await?; + + Ok(()) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub(super) fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(pdu) = self - .get_non_outlier_pdu(event_id)? - .map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - )? - .map(Arc::new) - { - Ok(Some(pdu)) - } else { - Ok(None) + pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result> { + if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await { + return Ok(Arc::new(pdu)); } + + self.eventid_outlierpdu + .get(event_id) + .await + .deserialized() + .map(Arc::new) + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + self.eventid_outlierpdu.get(event_id).await?; + + Ok(()) + } + + /// Like get_pdu(), but without the expense of fetching and parsing the data + pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool { + let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok()); + let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok()); + + //TODO: parallelize + non_outlier.await || outlier.await } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { + self.pduid_pdu.get(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap`. - pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + self.pduid_pdu.get(pdu_id).await.deserialized() } - pub(super) fn append_pdu( - &self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, - ) -> Result<()> { + pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + ); self.lasttimelinecount_cache .lock() - .expect("locked") + .await .insert(pdu.room_id.clone(), PduCount::Normal(count)); - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - Ok(()) + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes()); } - pub(super) fn prepend_backfill_pdu( - &self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject, - ) -> Result<()> { + pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; - - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(event_id.as_bytes())?; + ); - Ok(()) + self.eventid_pduid.insert(event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(event_id.as_bytes()); } /// Removes a pdu and creates a new one with the same id. - pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); + pub(super) async fn replace_pdu( + &self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, + ) -> Result<()> { + if self.pduid_pdu.get(pdu_id).await.is_not_found() { + return Err!(Request(NotFound("PDU does not exist."))); } + let pdu = serde_json::to_vec(pdu_json)?; + self.pduid_pdu.insert(pdu_id, &pdu); + Ok(()) } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. - pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result> { - let (prefix, current) = self.count_to_id(room_id, until, 1, true)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + pub(super) async fn pdus_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + ) -> Result + Send + 'a> { + let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?; + let stream = self + .pduid_pdu + .rev_raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); + + Ok(stream) } - pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result> { - let (prefix, current) = self.count_to_id(room_id, from, 1, false)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + pub(super) async fn pdus_after<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + ) -> Result + Send + 'a> { + let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?; + let stream = self + .pduid_pdu + .raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); + + Ok(stream) + } + + fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem { + let mut pdu = + serde_json::from_slice::(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + pdu.add_age().log_err().ok(); + let count = pdu_count(pdu_id); + + (count, pdu) } pub(super) fn increment_notification_counts( &self, room_id: &RoomId, notifies: Vec, highlights: Vec, - ) -> Result<()> { - let mut notifies_batch = Vec::new(); - let mut highlights_batch = Vec::new(); + ) { + let _cork = self.db.cork(); + for user in notifies { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - notifies_batch.push(userroom_id); + increment(&self.userroomid_notificationcount, &userroom_id); } + for user in highlights { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - highlights_batch.push(userroom_id); + increment(&self.userroomid_highlightcount, &userroom_id); } - - self.userroomid_notificationcount - .increment_batch(notifies_batch.iter().map(Vec::as_slice))?; - self.userroomid_highlightcount - .increment_batch(highlights_batch.iter().map(Vec::as_slice))?; - Ok(()) } - pub(super) fn count_to_id( + pub(super) async fn count_to_id( &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, ) -> Result<(Vec, Vec)> { let prefix = self .services .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? + .get_shortroomid(room_id) + .await + .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))? .to_be_bytes() .to_vec(); + let mut pdu_id = prefix.clone(); // +1 so we don't send the base event let count_raw = match count { @@ -326,17 +310,23 @@ impl Data { } /// Returns the `count` of this pdu's id. -pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { - let stride = size_of::(); +pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { + const STRIDE: usize = size_of::(); + let pdu_id_len = pdu_id.len(); - let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; - let second_last_u64 = - utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); + let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]); + let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]); - if matches!(second_last_u64, Ok(0)) { - Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) + if second_last_u64 == 0 { + PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)) } else { - Ok(PduCount::Normal(last_u64)) + PduCount::Normal(last_u64) } } + +//TODO: this is an ABA +fn increment(db: &Arc, key: &[u8]) { + let old = db.get_blocking(key); + let new = utils::increment(old.ok().as_deref()); + db.insert(key, &new); +} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 4f2352f81..6a26a1d53 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,19 +1,20 @@ mod data; use std::{ + cmp, collections::{BTreeMap, HashSet}, fmt::Write, sync::Arc, }; use conduit::{ - debug, error, info, + debug, err, error, info, pdu::{EventHash, PduBuilder, PduCount, PduEvent}, utils, - utils::{MutexMap, MutexMapGuard}, - validated, warn, Error, Result, Server, + utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, + validated, warn, Err, Error, Result, Server, }; -use itertools::Itertools; +use futures::{future, future::ready, Future, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -39,6 +40,7 @@ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; use self::data::Data; +pub use self::data::PdusIterItem; use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, rooms::state_compressor::CompressedStateEvent, sending, server_keys, Dep, @@ -129,6 +131,7 @@ impl crate::Service for Service { } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + /* let lasttimelinecount_cache = self .db .lasttimelinecount_cache @@ -136,6 +139,7 @@ impl crate::Service for Service { .expect("locked") .len(); writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?; + */ let mutex_insert = self.mutex_insert.len(); writeln!(out, "insert_mutex: {mutex_insert}")?; @@ -144,11 +148,13 @@ impl crate::Service for Service { } fn clear_cache(&self) { + /* self.db .lasttimelinecount_cache .lock() .expect("locked") .clear(); + */ } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -156,28 +162,32 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self), level = "debug")] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result> { + self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - self.all_pdus(user_id!("@placeholder:conduwuit.placeholder"), room_id)? - .last() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result> { + self.pdus_until(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) + .await? + .next() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - self.db.last_timeline_count(sender_user, room_id) + pub async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + self.db.last_timeline_count(sender_user, room_id).await } /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.db.get_pdu_count(event_id) } + pub async fn get_pdu_count(&self, event_id: &EventId) -> Result { self.db.get_pdu_count(event_id).await } // TODO Is this the same as the function above? /* @@ -203,49 +213,56 @@ impl Service { */ /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_json(event_id) + pub async fn get_pdu_json(&self, event_id: &EventId) -> Result { + self.db.get_pdu_json(event_id).await } /// Returns the json of a pdu. #[inline] - pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu_json(event_id) + pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { + self.db.get_non_outlier_pdu_json(event_id).await } /// Returns the pdu's id. #[inline] - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.db.get_pdu_id(event_id) + pub async fn get_pdu_id(&self, event_id: &EventId) -> Result> { + self.db.get_pdu_id(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. #[inline] - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu(event_id) + pub async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { + self.db.get_non_outlier_pdu(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>> { self.db.get_pdu(event_id) } + pub async fn get_pdu(&self, event_id: &EventId) -> Result> { self.db.get_pdu(event_id).await } + + /// Checks if pdu exists + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future + Send + 'a { + self.db.pdu_exists(event_id) + } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_from_id(pdu_id) } + pub async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { self.db.get_pdu_from_id(pdu_id).await } /// Returns the pdu as a `BTreeMap`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.db.get_pdu_json_from_id(pdu_id) + pub async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + self.db.get_pdu_json_from_id(pdu_id).await } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self), level = "debug")] - pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { - self.db.replace_pdu(pdu_id, pdu_json, pdu) + pub async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + self.db.replace_pdu(pdu_id, pdu_json, pdu).await } /// Creates a new persisted data unit and adds it to a room. @@ -268,8 +285,9 @@ impl Service { let shortroomid = self .services .short - .get_shortroomid(&pdu.room_id)? - .expect("room exists"); + .get_shortroomid(&pdu.room_id) + .await + .map_err(|_| err!(Database("Room does not exist")))?; // Make unsigned fields correct. This is not properly documented in the spec, // but state events need to have previous content in the unsigned field, so @@ -279,17 +297,17 @@ impl Service { .entry("unsigned".to_owned()) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = self + if let Ok(shortstatehash) = self .services .state_accessor .pdu_shortstatehash(&pdu.event_id) - .unwrap() + .await { - if let Some(prev_state) = self + if let Ok(prev_state) = self .services .state_accessor .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() + .await { unsigned.insert( "prev_content".to_owned(), @@ -318,10 +336,12 @@ impl Service { // We must keep track of all events that have been referenced. self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, leaves, state_lock) + .await; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; @@ -330,17 +350,17 @@ impl Service { // appending fails self.services .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1); self.services .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + .reset_notification_counts(&pdu.sender, &pdu.room_id); - let count2 = self.services.globals.next_count()?; + let count2 = self.services.globals.next_count().unwrap(); let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); // Insert pdu - self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; + self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; drop(insert_lock); @@ -348,12 +368,9 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? + .room_state_get_content(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("invalid m.room.power_levels event"))) .unwrap_or_default(); let sync_pdu = pdu.to_sync_room_event(); @@ -365,7 +382,9 @@ impl Service { .services .state_cache .active_local_users_in_room(&pdu.room_id) - .collect_vec(); + .map(ToOwned::to_owned) + .collect::>() + .await; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { @@ -386,23 +405,19 @@ impl Service { let rules_for_user = self .services .account_data - .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? - .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid push rules event in db for user ID {user}: {e}"); - Error::bad_database("Invalid push rules event in db.") - }) - }) - .transpose()? - .map_or_else(|| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); + .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into()) + .await + .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) + .map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); let mut highlight = false; let mut notify = false; - for action in - self.services - .pusher - .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? + for action in self + .services + .pusher + .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id) + .await? { match action { Action::Notify => notify = true, @@ -421,31 +436,36 @@ impl Service { highlights.push(user.clone()); } - for push_key in self.services.pusher.get_pushkeys(user) { - self.services - .sending - .send_pdu_push(&pdu_id, user, push_key?)?; - } + self.services + .pusher + .get_pushkeys(user) + .ready_for_each(|push_key| { + self.services + .sending + .send_pdu_push(&pdu_id, user, push_key.to_owned()) + .expect("TODO: replace with future"); + }) + .await; } self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights)?; + .increment_notification_counts(&pdu.room_id, notifies, highlights); match pdu.kind { TimelineEventType::RoomRedaction => { use RoomVersionId::*; - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? + { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -457,13 +477,13 @@ impl Service { })?; if let Some(redact_id) = &content.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? + { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -492,7 +512,7 @@ impl Service { let invite_state = match content.membership { MembershipState::Invite => { - let state = self.services.state.calculate_invite_state(pdu)?; + let state = self.services.state.calculate_invite_state(pdu).await?; Some(state) }, _ => None, @@ -500,15 +520,18 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - self.services.state_cache.update_membership( - &pdu.room_id, - &target_user_id, - content, - &pdu.sender, - invite_state, - None, - true, - )?; + self.services + .state_cache + .update_membership( + &pdu.room_id, + &target_user_id, + content, + &pdu.sender, + invite_state, + None, + true, + ) + .await?; } }, TimelineEventType::RoomMessage => { @@ -516,9 +539,7 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); if self.services.admin.is_admin_command(pdu, &body).await { self.services @@ -531,10 +552,10 @@ impl Service { } if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { + if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(PduCount::Normal(count2), related_pducount); } } @@ -545,14 +566,17 @@ impl Service { } => { // We need to do it again here, because replies don't have // event_id as a top level field - if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? { + if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(PduCount::Normal(count2), related_pducount); } }, Relation::Thread(thread) => { - self.services.threads.add_to_thread(&thread.event_id, pdu)?; + self.services + .threads + .add_to_thread(&thread.event_id, pdu) + .await?; }, _ => {}, // TODO: Aggregate other types } @@ -562,7 +586,8 @@ impl Service { if self .services .state_cache - .appservice_in_room(&pdu.room_id, appservice)? + .appservice_in_room(&pdu.room_id, appservice) + .await { self.services .sending @@ -596,15 +621,14 @@ impl Service { .as_ref() .map_or(false, |state_key| users.is_match(state_key)) }; - let matching_aliases = |aliases: &NamespaceRegex| { + let matching_aliases = |aliases: NamespaceRegex| { self.services .alias .local_aliases_for_room(&pdu.room_id) - .filter_map(Result::ok) - .any(|room_alias| aliases.is_match(room_alias.as_str())) + .ready_any(move |room_alias| aliases.is_match(room_alias.as_str())) }; - if matching_aliases(&appservice.aliases) + if matching_aliases(appservice.aliases.clone()).await || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { @@ -617,7 +641,7 @@ impl Service { Ok(pdu_id) } - pub fn create_hash_and_sign_event( + pub async fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -636,47 +660,59 @@ impl Service { let prev_events: Vec<_> = self .services .state - .get_forward_extremities(room_id)? - .into_iter() + .get_forward_extremities(room_id) .take(20) - .collect(); + .map(Arc::from) + .collect() + .await; // If there was no create event yet, assume we are creating a room - let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + let room_version_id = self + .services + .state + .get_room_version(room_id) + .await + .or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + let content = serde_json::from_str::(content.get()) + .expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - let auth_events = - self.services - .state - .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + let auth_events = self + .services + .state + .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content) + .await?; // Our depth is the maximum depth of prev_events + 1 let depth = prev_events .iter() - .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) + .stream() + .map(Ok) + .and_then(|event_id| self.get_pdu(event_id)) + .and_then(|pdu| future::ok(pdu.depth)) + .ignore_err() + .ready_fold(uint!(0), cmp::max) + .await .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.services - .state_accessor - .room_state_get(room_id, &event_type.to_string().into(), state_key)? + if let Ok(prev_pdu) = self + .services + .state_accessor + .room_state_get(room_id, &event_type.to_string().into(), state_key) + .await { unsigned.insert( "prev_content".to_owned(), @@ -727,19 +763,22 @@ impl Service { signatures: None, }; + let auth_fetch = |k: &StateEventType, s: &str| { + let key = (k.clone(), s.to_owned()); + ready(auth_events.get(&key)) + }; + let auth_check = state_res::auth_check( &room_version, &pdu, - None::, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), + None, // TODO: third_party_invite + auth_fetch, ) - .map_err(|e| { - error!("Auth check failed: {:?}", e); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed.") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Event is not authorized.")); + return Err!(Request(Forbidden("Event is not authorized."))); } // Hash and sign @@ -795,7 +834,8 @@ impl Service { let _shorteventid = self .services .short - .get_or_create_shorteventid(&pdu.event_id)?; + .get_or_create_shorteventid(&pdu.event_id) + .await; Ok((pdu, pdu_json)) } @@ -811,108 +851,117 @@ impl Service { room_id: &RoomId, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = self.services.admin.get_admin_room()? { - if admin_room == room_id { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Encryption is not allowed in the admins room", - )); - }, - TimelineEventType::RoomMember => { - let target = pdu - .state_key() - .filter(|v| v.starts_with('@')) - .unwrap_or(sender.as_str()); - let server_user = &self.services.globals.server_user.to_string(); - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu"))?; - - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Server user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot leave from admins room.", - )); - } + let (pdu, pdu_json) = self + .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) + .await?; - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot leave from admins room.", - )); - } + if self.services.admin.is_admin_room(&pdu.room_id).await { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + warn!("Encryption is not allowed in the admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Encryption is not allowed in the admins room", + )); + }, + TimelineEventType::RoomMember => { + let target = pdu + .state_key() + .filter(|v| v.starts_with('@')) + .unwrap_or(sender.as_str()); + let server_user = &self.services.globals.server_user.to_string(); + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu"))?; + + if content.membership == MembershipState::Leave { + if target == server_user { + warn!("Server user cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Server user cannot leave from admins room.", + )); } - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Server user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot be banned in admins room.", - )); - } + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + warn!("Last admin cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Last admin cannot leave from admins room.", + )); + } + } - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot be banned in admins room.", - )); - } + if content.membership == MembershipState::Ban && pdu.state_key().is_some() { + if target == server_user { + warn!("Server user cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Server user cannot be banned in admins room.", + )); } - }, - _ => {}, - } + + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + warn!("Last admin cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Last admin cannot be banned in admins room.", + )); + } + } + }, + _ => {}, } } // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { use RoomVersionId::*; - match self.services.state.get_room_version(&pdu.room_id)? { + match self.services.state.get_room_version(&pdu.room_id).await? { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } }; }, _ => { let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + .map_err(|e| err!(Database("Invalid content in redaction pdu: {e:?}")))?; if let Some(redact_id) = &content.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } } }, @@ -922,7 +971,7 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehashid = self.services.state.append_to_state(&pdu)?; + let statehashid = self.services.state.append_to_state(&pdu).await?; let pdu_id = self .append_pdu( @@ -939,14 +988,15 @@ impl Service { // in time where events in the current room state do not exist self.services .state - .set_room_state(room_id, statehashid, state_lock)?; + .set_room_state(&pdu.room_id, statehashid, state_lock); let mut servers: HashSet = self .services .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .room_servers(&pdu.room_id) + .map(ToOwned::to_owned) + .collect() + .await; // In case we are kicking or banning a user, we need to inform their server of // the change @@ -966,7 +1016,8 @@ impl Service { self.services .sending - .send_pdu_servers(servers.into_iter(), &pdu_id)?; + .send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id) + .await?; Ok(pdu.event_id) } @@ -988,15 +1039,19 @@ impl Service { // fail. self.services .state - .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) + .await?; if soft_fail { self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock) + .await; + return Ok(None); } @@ -1009,71 +1064,71 @@ impl Service { /// Returns an iterator over all PDUs in a room. #[inline] - pub fn all_pdus<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, - ) -> Result> + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()) + pub async fn all_pdus<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, + ) -> Result + Send + 'a> { + self.pdus_after(user_id, room_id, PduCount::min()).await } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_until<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a> { - self.db.pdus_until(user_id, room_id, until) + pub async fn pdus_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, + ) -> Result + Send + 'a> { + self.db.pdus_until(user_id, room_id, until).await } /// Returns an iterator over all events and their token in a room that /// happened after the event with id `from` in chronological order. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_after<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a> { - self.db.pdus_after(user_id, room_id, from) + pub async fn pdus_after<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, + ) -> Result + Send + 'a> { + self.db.pdus_after(user_id, room_id, from).await } /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { + pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { // TODO: Don't reserialize, keep original json - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + let Ok(pdu_id) = self.get_pdu_id(event_id).await else { + // If event does not exist, just noop + return Ok(()); + }; - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(body) = content.body { - self.services - .search - .deindex_pdu(shortroomid, &pdu_id, &body)?; - } + let mut pdu = self + .get_pdu_from_id(&pdu_id) + .await + .map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?; + + if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Some(body) = content.body { + self.services + .search + .deindex_pdu(shortroomid, &pdu_id, &body); } + } - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; - pdu.redact(room_version_id, reason)?; + pdu.redact(room_version_id, reason)?; - self.replace_pdu( - &pdu_id, - &utils::to_canonical_object(&pdu).map_err(|e| { - error!("Failed to convert PDU to canonical JSON: {}", e); - Error::bad_database("Failed to convert PDU to canonical JSON.") - })?, - &pdu, - )?; - } - // If event does not exist, just noop - Ok(()) + let obj = utils::to_canonical_object(&pdu) + .map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?; + + self.replace_pdu(&pdu_id, &obj, &pdu).await } #[tracing::instrument(skip(self))] pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { let first_pdu = self - .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .expect("Room is not empty")?; + .await + .expect("Room is not empty"); if first_pdu.0 < from { // No backfill required, there are still events between them @@ -1083,17 +1138,18 @@ impl Service { let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await .map(|ev| { serde_json::from_str(ev.content.get()) .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + .unwrap() }) - .transpose()? .unwrap_or_default(); let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) { - Some(user_id.server_name().to_owned()) + Some(user_id.server_name()) } else { None } @@ -1103,34 +1159,43 @@ impl Service { .services .alias .local_aliases_for_room(room_id) - .filter_map(|alias| { - alias - .ok() - .filter(|alias| !self.services.globals.server_is_ours(alias.server_name())) - .map(|alias| alias.server_name().to_owned()) + .ready_filter_map(|alias| { + self.services + .globals + .server_is_ours(alias.server_name()) + .then_some(alias.server_name()) }); - let servers = room_mods + let mut servers = room_mods + .stream() .chain(room_alias_servers) - .chain(self.services.server.config.trusted_servers.clone()) - .filter(|server_name| { - if self.services.globals.server_is_ours(server_name) { - return false; - } - + .map(ToOwned::to_owned) + .chain( + self.services + .server + .config + .trusted_servers + .iter() + .map(ToOwned::to_owned) + .stream(), + ) + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)) + .filter_map(|server_name| async move { self.services .state_cache - .server_in_room(server_name, room_id) - .unwrap_or(false) - }); + .server_in_room(&server_name, room_id) + .await + .then_some(server_name) + }) + .boxed(); - for backfill_server in servers { + while let Some(ref backfill_server) = servers.next().await { info!("Asking {backfill_server} for backfill"); let response = self .services .sending .send_federation_request( - &backfill_server, + backfill_server, federation::backfill::get_backfill::v1::Request { room_id: room_id.to_owned(), v: vec![first_pdu.1.event_id.as_ref().to_owned()], @@ -1142,7 +1207,7 @@ impl Service { Ok(response) => { let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(&backfill_server, pdu, &pub_key_map).await { + if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await { warn!("Failed to add backfilled pdu in room {room_id}: {e}"); } } @@ -1163,7 +1228,7 @@ impl Service { &self, origin: &ServerName, pdu: Box, pub_key_map: &RwLock>>, ) -> Result<()> { - let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; // Lock so we cannot backfill the same pdu twice at the same time let mutex_lock = self @@ -1174,7 +1239,7 @@ impl Service { .await; // Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.get_pdu_id(&event_id)? { + if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { let pdu_id = pdu_id.to_vec(); debug!("We already know {event_id} at {pdu_id:?}"); return Ok(()); @@ -1190,36 +1255,38 @@ impl Service { .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) .await?; - let value = self.get_pdu_json(&event_id)?.expect("We just created it"); - let pdu = self.get_pdu(&event_id)?.expect("We just created it"); + let value = self + .get_pdu_json(&event_id) + .await + .expect("We just created it"); + let pdu = self.get_pdu(&event_id).await.expect("We just created it"); let shortroomid = self .services .short - .get_shortroomid(&room_id)? + .get_shortroomid(&room_id) + .await .expect("room exists"); let insert_lock = self.mutex_insert.lock(&room_id).await; let max = u64::MAX; - let count = self.services.globals.next_count()?; + let count = self.services.globals.next_count().unwrap(); let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); + pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes()); // Insert pdu - self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; + self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); drop(insert_lock); if pdu.kind == TimelineEventType::RoomMessage { let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + .map_err(|e| err!(Database("Invalid content in pdu: {e:?}")))?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); } } drop(mutex_lock); diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 3cf1cdd59..bcfce6168 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -46,7 +46,7 @@ impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout); + debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}"); // update clients self.typing .write() @@ -54,17 +54,19 @@ impl Service { .entry(room_id.to_owned()) .or_default() .insert(user_id.to_owned(), timeout); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, true)?; + self.federation_send(room_id, user_id, true).await?; } Ok(()) @@ -72,7 +74,7 @@ impl Service { /// Removes a user from typing before the timeout is reached. pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - debug_info!("typing stopped {:?} in {:?}", user_id, room_id); + debug_info!("typing stopped {user_id:?} in {room_id:?}"); // update clients self.typing .write() @@ -80,31 +82,31 @@ impl Service { .entry(room_id.to_owned()) .or_default() .remove(user_id); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, false)?; + self.federation_send(room_id, user_id, false).await?; } Ok(()) } - pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> { + pub async fn wait_for_update(&self, room_id: &RoomId) { let mut receiver = self.typing_update_sender.subscribe(); while let Ok(next) = receiver.recv().await { if next == room_id { break; } } - - Ok(()) } /// Makes sure that typing events with old timestamps get removed. @@ -123,30 +125,30 @@ impl Service { removable.push(user.clone()); } } - - drop(typing); }; if !removable.is_empty() { let typing = &mut self.typing.write().await; let room = typing.entry(room_id.to_owned()).or_default(); for user in &removable { - debug_info!("typing timeout {:?} in {:?}", &user, room_id); + debug_info!("typing timeout {user:?} in {room_id:?}"); room.remove(user); } + // update clients self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation - for user in removable { - if self.services.globals.user_is_local(&user) { - self.federation_send(room_id, &user, false)?; + for user in &removable { + if self.services.globals.user_is_local(user) { + self.federation_send(room_id, user, false).await?; } } } @@ -183,7 +185,7 @@ impl Service { }) } - fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { + async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { debug_assert!( self.services.globals.user_is_local(user_id), "tried to broadcast typing status of remote user", @@ -197,7 +199,8 @@ impl Service { self.services .sending - .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; + .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing")) + .await?; Ok(()) } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index c71316153..d4d9874c2 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::Result; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::{RoomId, UserId}; use crate::{globals, rooms, Dep}; @@ -11,13 +12,13 @@ pub(super) struct Data { userroomid_highlightcount: Arc, roomuserid_lastnotificationread: Arc, roomsynctoken_shortstatehash: Arc, - userroomid_joined: Arc, services: Services, } struct Services { globals: Dep, short: Dep, + state_cache: Dep, } impl Data { @@ -28,15 +29,15 @@ impl Data { userroomid_highlightcount: db["userroomid_highlightcount"].clone(), roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), services: Services { globals: args.depend::("globals"), short: args.depend::("rooms::short"), + state_cache: args.depend::("rooms::state_cache"), }, } } - pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -45,128 +46,73 @@ impl Data { roomuser_id.extend_from_slice(user_id.as_bytes()); self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + .insert(&userroom_id, &0_u64.to_be_bytes()); self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + .insert(&userroom_id, &0_u64.to_be_bytes()); self.roomuserid_lastnotificationread - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) + .insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes()); } - pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - + pub(super) async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); self.userroomid_notificationcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - + pub(super) async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); self.userroomid_highlightcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastnotificationread - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) + pub(super) async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.roomuserid_lastnotificationread + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } - pub(super) fn associate_token_shortstatehash( - &self, room_id: &RoomId, token: u64, shortstatehash: u64, - ) -> Result<()> { + pub(super) async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { let shortroomid = self .services .short - .get_shortroomid(room_id)? + .get_shortroomid(room_id) + .await .expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) + .insert(&key, &shortstatehash.to_be_bytes()); } - pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self - .services - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); + pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + let shortroomid = self.services.short.get_shortroomid(room_id).await?; + let key: &[u64] = &[shortroomid, token]; self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) - }) - .transpose() + .qry(key) + .await + .deserialized() } + //TODO: optimize; replace point-queries with dual iteration pub(super) fn get_shared_rooms<'a>( - &'a self, users: Vec, - ) -> Result> + 'a>> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xFF) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - .saturating_add(1); // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(Result::ok) - }); - - // We use the default compare function because keys are sorted correctly (not - // reversed) - Ok(Box::new( - utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, - ) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), - )) + &'a self, user_a: &'a UserId, user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + self.services + .state_cache + .rooms_joined(user_a) + .filter(|room_id| self.services.state_cache.is_joined(user_b, room_id)) } } diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 93d38470f..d9d90ecf9 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -3,7 +3,8 @@ mod data; use std::sync::Arc; use conduit::Result; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use futures::{pin_mut, Stream, StreamExt}; +use ruma::{RoomId, UserId}; use self::data::Data; @@ -22,32 +23,49 @@ impl crate::Service for Service { } impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) + #[inline] + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { + self.db.reset_notification_counts(user_id, room_id); } - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.notification_count(user_id, room_id) + #[inline] + pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.notification_count(user_id, room_id).await } - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.highlight_count(user_id, room_id) + #[inline] + pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.highlight_count(user_id, room_id).await } - pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_notification_read(user_id, room_id) + #[inline] + pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.last_notification_read(user_id, room_id).await } - pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + #[inline] + pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { self.db .associate_token_shortstatehash(room_id, token, shortstatehash) + .await; } - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - self.db.get_token_shortstatehash(room_id, token) + #[inline] + pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { + self.db.get_token_shortstatehash(room_id, token).await } - pub fn get_shared_rooms(&self, users: Vec) -> Result> + '_> { - self.db.get_shared_rooms(users) + #[inline] + pub fn get_shared_rooms<'a>( + &'a self, user_a: &'a UserId, user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + self.db.get_shared_rooms(user_a, user_b) + } + + pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool { + let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + + pin_mut!(get_shared_rooms); + get_shared_rooms.next().await.is_some() } } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 6c8e2544d..6f4b5b970 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,14 +1,21 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Database, Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; use crate::{globals, Dep}; -type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; -type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; +pub(super) type OutgoingItem = (Key, SendingEvent, Destination); +pub(super) type SendingItem = (Key, SendingEvent); +pub(super) type QueueItem = (Key, SendingEvent); +pub(super) type Key = Vec; pub struct Data { servercurrentevent_data: Arc, @@ -36,58 +43,82 @@ impl Data { } } - #[inline] - pub fn active_requests(&self) -> OutgoingSendingIter<'_> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) + pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); } + + pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) { + let prefix = destination.get_prefix(); + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; } - #[inline] - pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> { + pub(super) async fn delete_all_requests_for(&self, destination: &Destination) { let prefix = destination.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; + + self.servernameevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servernameevent_data.remove(key)) + .await; } - pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) } + pub(super) fn mark_as_active(&self, events: &[QueueItem]) { + for (key, e) in events { + if key.is_empty() { + continue; + } - pub(super) fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { - let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { - self.servercurrentevent_data.remove(&key)?; + let value = if let SendingEvent::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value); + self.servernameevent_data.remove(key); } + } - Ok(()) + #[inline] + pub fn active_requests(&self) -> impl Stream + Send + '_ { + self.servercurrentevent_data + .raw_stream() + .ignore_err() + .map(|(key, val)| { + let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); + + (key.to_vec(), event, dest) + }) } - pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { + #[inline] + pub fn active_requests_for(&self, destination: &Destination) -> impl Stream + Send + '_ { let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } - - for (key, _) in self.servernameevent_data.scan_prefix(prefix) { - self.servernameevent_data.remove(&key).unwrap(); - } + self.servercurrentevent_data + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - Ok(()) + (key.to_vec(), event) + }) } - pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { + pub(super) fn queue_requests(&self, requests: &[(&SendingEvent, &Destination)]) -> Vec> { let mut batch = Vec::new(); let mut keys = Vec::new(); - for (destination, event) in requests { + for (event, destination) in requests { let mut key = destination.get_prefix(); if let SendingEvent::Pdu(value) = &event { key.extend_from_slice(value); } else { - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); } let value = if let SendingEvent::Edu(value) = &event { &**value @@ -97,56 +128,39 @@ impl Data { batch.push((key.clone(), value.to_owned())); keys.push(key); } - self.servernameevent_data - .insert_batch(batch.iter().map(database::KeyVal::from))?; - Ok(keys) - } - pub fn queued_requests<'a>( - &'a self, destination: &Destination, - ) -> Box)>> + 'a> { - let prefix = destination.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); + self.servernameevent_data.insert_batch(batch.iter()); + keys } - pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { - for (e, key) in events { - if key.is_empty() { - continue; - } - - let value = if let SendingEvent::Edu(value) = &e { - &**value - } else { - &[] - }; - self.servercurrentevent_data.insert(key, value)?; - self.servernameevent_data.remove(key)?; - } + pub fn queued_requests(&self, destination: &Destination) -> impl Stream + Send + '_ { + let prefix = destination.get_prefix(); + self.servernameevent_data + .stream_raw_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - Ok(()) + (key.to_vec(), event) + }) } - pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) { self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + .insert(server_name.as_bytes(), &last_count.to_be_bytes()); } - pub fn get_latest_educount(&self, server_name: &ServerName) -> Result { + pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 { self.servername_educount - .get(server_name.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) + .get(server_name) + .await + .deserialized() + .unwrap_or(0) } } #[tracing::instrument(skip(key), level = "debug")] -fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent)> { +fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, SendingEvent)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { let mut parts = key[1..].splitn(2, |&b| b == 0xFF); @@ -164,7 +178,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, if value.is_empty() { SendingEvent::Pdu(event.to_vec()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else if key.starts_with(b"$") { @@ -192,7 +206,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent::Pdu(event.to_vec()) } else { // I'm pretty sure this should never be called - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else { @@ -214,7 +228,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, if value.is_empty() { SendingEvent::Pdu(event.to_vec()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) }) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index b90ea3618..e3582f2ea 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,10 +7,11 @@ mod sender; use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, warn, Result, Server}; +use conduit::{err, utils::ReadyExt, warn, Result, Server}; +use futures::{future::ready, Stream, StreamExt, TryStreamExt}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, - OwnedServerName, RoomId, ServerName, UserId, + RoomId, ServerName, UserId, }; use tokio::sync::Mutex; @@ -104,7 +105,7 @@ impl Service { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -117,7 +118,7 @@ impl Service { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -126,30 +127,31 @@ impl Service { } #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] - pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_pdu_servers(servers, pdu_id) + self.send_pdu_servers(servers, pdu_id).await } #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] - pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) - .collect::>(); + pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()> + where + S: Stream + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into()))) + .collect::>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::>()); + for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { dest, @@ -166,7 +168,7 @@ impl Service { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -175,30 +177,30 @@ impl Service { } #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] - pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_edu_servers(servers, serialized) + self.send_edu_servers(servers, serialized).await } #[tracing::instrument(skip(self, servers, serialized), level = "debug")] - pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) - .collect::>(); + pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec) -> Result<()> + where + S: Stream + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))) + .collect::>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::>()); for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { @@ -212,29 +214,33 @@ impl Service { } #[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + pub async fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.flush_servers(servers) + self.flush_servers(servers).await } #[tracing::instrument(skip(self, servers), level = "debug")] - pub fn flush_servers>(&self, servers: I) -> Result<()> { - let requests = servers.into_iter().map(Destination::Normal); - for dest in requests { - self.dispatch(Msg { - dest, - event: SendingEvent::Flush, - queue_id: Vec::::new(), - })?; - } - - Ok(()) + pub async fn flush_servers<'a, S>(&self, servers: S) -> Result<()> + where + S: Stream + Send + 'a, + { + servers + .map(ToOwned::to_owned) + .map(Destination::Normal) + .map(Ok) + .try_for_each(|dest| { + ready(self.dispatch(Msg { + dest, + event: SendingEvent::Flush, + queue_id: Vec::::new(), + })) + }) + .await } #[tracing::instrument(skip_all, name = "request")] @@ -263,11 +269,10 @@ impl Service { /// Cleanup event data /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self), level = "debug")] - pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + pub async fn cleanup_events(&self, appservice_id: String) { self.db - .delete_all_requests_for(&Destination::Appservice(appservice_id))?; - - Ok(()) + .delete_all_requests_for(&Destination::Appservice(appservice_id)) + .await; } fn dispatch(&self, msg: Msg) -> Result<()> { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 206bf92bb..4db9922ae 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -7,18 +7,15 @@ use std::{ use base64::{engine::general_purpose, Engine as _}; use conduit::{ - debug, debug_warn, error, trace, - utils::{calculate_hash, math::continue_exponential_backoff_secs}, + debug, debug_warn, err, trace, + utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt}, warn, Error, Result, }; -use federation::transactions::send_transaction_message; -use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; +use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ - api::federation::{ - self, - transactions::edu::{ - DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, - }, + api::federation::transactions::{ + edu::{DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap}, + send_transaction_message, }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, @@ -28,7 +25,7 @@ use ruma::{ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; -use super::{appservice, Destination, Msg, SendingEvent, Service}; +use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service}; #[derive(Debug)] enum TransactionStatus { @@ -50,20 +47,20 @@ const CLEANUP_TIMEOUT_MS: u64 = 3500; impl Service { #[tracing::instrument(skip_all, name = "sender")] pub(super) async fn sender(&self) -> Result<()> { - let receiver = self.receiver.lock().await; - let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); + let mut futures: SendingFutures<'_> = FuturesUnordered::new(); + let receiver = self.receiver.lock().await; - self.initial_requests(&futures, &mut statuses); + self.initial_requests(&mut futures, &mut statuses).await; loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { request = receiver.recv_async() => match request { - Ok(request) => self.handle_request(request, &futures, &mut statuses), + Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await, Err(_) => break, }, Some(response) = futures.next() => { - self.handle_response(response, &futures, &mut statuses); + self.handle_response(response, &mut futures, &mut statuses).await; }, } } @@ -72,18 +69,16 @@ impl Service { Ok(()) } - fn handle_response<'a>( - &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + async fn handle_response<'a>( + &'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { match response { - Ok(dest) => self.handle_response_ok(&dest, futures, statuses), - Err((dest, e)) => Self::handle_response_err(dest, futures, statuses, &e), + Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await, + Err((dest, e)) => Self::handle_response_err(dest, statuses, &e), }; } - fn handle_response_err( - dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, - ) { + fn handle_response_err(dest: Destination, statuses: &mut CurTransactionStatus, e: &Error) { debug!(dest = ?dest, "{e:?}"); statuses.entry(dest).and_modify(|e| { *e = match e { @@ -94,39 +89,40 @@ impl Service { }); } - fn handle_response_ok<'a>( - &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_response_ok<'a>( + &'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { let _cork = self.db.db.cork(); - self.db - .delete_all_active_requests_for(dest) - .expect("all active requests deleted"); + self.db.delete_all_active_requests_for(dest).await; // Find events that have been added since starting the last request let new_events = self .db .queued_requests(dest) - .filter_map(Result::ok) .take(DEQUEUE_LIMIT) - .collect::>(); + .collect::>() + .await; // Insert any pdus we found if !new_events.is_empty() { - self.db - .mark_as_active(&new_events) - .expect("marked as active"); - let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); - futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec))); + self.db.mark_as_active(&new_events); + + let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect(); + futures.push(self.send_events(dest.clone(), new_events_vec).boxed()); } else { statuses.remove(dest); } } - fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { - let iv = vec![(msg.event, msg.queue_id)]; - if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_request<'a>( + &'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, + ) { + let iv = vec![(msg.queue_id, msg.event)]; + if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await { if !events.is_empty() { - futures.push(Box::pin(self.send_events(msg.dest, events))); + futures.push(self.send_events(msg.dest, events).boxed()); } else { statuses.remove(&msg.dest); } @@ -142,7 +138,7 @@ impl Service { tokio::select! { () = sleep_until(deadline.into()) => break, response = futures.next() => match response { - Some(response) => self.handle_response(response, futures, statuses), + Some(response) => self.handle_response(response, futures, statuses).await, None => return, } } @@ -151,16 +147,17 @@ impl Service { debug_warn!("Leaving with {} unfinished requests...", futures.len()); } - fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::>::new(); - for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { + let mut active = self.db.active_requests().boxed(); + + while let Some((key, event, dest)) = active.next().await { let entry = txns.entry(dest.clone()).or_default(); if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep { - warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); - self.db - .delete_active_request(&key) - .expect("active request deleted"); + warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key)); + self.db.delete_active_request(&key); } else { entry.push(event); } @@ -169,16 +166,16 @@ impl Service { for (dest, events) in txns { if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); - futures.push(Box::pin(self.send_events(dest.clone(), events))); + futures.push(self.send_events(dest.clone(), events).boxed()); } } } #[tracing::instrument(skip_all, level = "debug")] - fn select_events( + async fn select_events( &self, dest: &Destination, - new_events: Vec<(SendingEvent, Vec)>, // Events we want to send: event and full key + new_events: Vec, // Events we want to send: event and full key statuses: &mut CurTransactionStatus, ) -> Result>> { let (allow, retry) = self.select_events_current(dest.clone(), statuses)?; @@ -195,8 +192,8 @@ impl Service { if retry { self.db .active_requests_for(dest) - .filter_map(Result::ok) - .for_each(|(_, e)| events.push(e)); + .ready_for_each(|(_, e)| events.push(e)) + .await; return Ok(Some(events)); } @@ -204,17 +201,17 @@ impl Service { // Compose the next transaction let _cork = self.db.db.cork(); if !new_events.is_empty() { - self.db.mark_as_active(&new_events)?; - for (e, _) in new_events { + self.db.mark_as_active(&new_events); + for (_, e) in new_events { events.push(e); } } // Add EDU's into the transaction if let Destination::Normal(server_name) = dest { - if let Ok((select_edus, last_count)) = self.select_edus(server_name) { + if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { events.extend(select_edus.into_iter().map(SendingEvent::Edu)); - self.db.set_latest_educount(server_name, last_count)?; + self.db.set_latest_educount(server_name, last_count); } } @@ -248,26 +245,32 @@ impl Service { } #[tracing::instrument(skip_all, level = "debug")] - fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { + async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu - let since = self.db.get_latest_educount(server_name)?; + let since = self.db.get_latest_educount(server_name).await; let mut events = Vec::new(); let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - for room_id in self.services.state_cache.server_rooms(server_name) { - let room_id = room_id?; + let server_rooms = self.services.state_cache.server_rooms(server_name); + + pin_mut!(server_rooms); + while let Some(room_id) = server_rooms.next().await { // Look for device list updates in this room device_list_changes.extend( self.services .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok) - .filter(|user_id| self.services.globals.user_is_local(user_id)), + .keys_changed(room_id.as_str(), since, None) + .ready_filter(|user_id| self.services.globals.user_is_local(user_id)) + .map(ToOwned::to_owned) + .collect::>() + .await, ); if self.server.config.allow_outgoing_read_receipts - && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? + && !self + .select_edus_receipts(room_id, since, &mut max_edu_count, &mut events) + .await? { break; } @@ -290,19 +293,22 @@ impl Service { } if self.server.config.allow_outgoing_presence { - self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events) + .await?; } Ok((events, max_edu_count)) } /// Look for presence - fn select_edus_presence( + async fn select_edus_presence( &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, ) -> Result { - // Look for presence updates for this server + let presence_since = self.services.presence.presence_since(since); + + pin_mut!(presence_since); let mut presence_updates = Vec::new(); - for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { + while let Some((user_id, count, presence_bytes)) = presence_since.next().await { *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { @@ -312,7 +318,8 @@ impl Service { if !self .services .state_cache - .server_sees_user(server_name, &user_id)? + .server_sees_user(server_name, &user_id) + .await { continue; } @@ -320,7 +327,9 @@ impl Service { let presence_event = self .services .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; + .from_json_bytes_to_event(&presence_bytes, &user_id) + .await?; + presence_updates.push(PresenceUpdate { user_id, presence: presence_event.content.presence, @@ -346,32 +355,33 @@ impl Service { } /// Look for read receipts in this room - fn select_edus_receipts( + async fn select_edus_receipts( &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec>, ) -> Result { - for r in self + let receipts = self .services .read_receipt - .readreceipts_since(room_id, since) - { - let (user_id, count, read_receipt) = r?; - *max_edu_count = cmp::max(count, *max_edu_count); + .readreceipts_since(room_id, since); + pin_mut!(receipts); + while let Some((user_id, count, read_receipt)) = receipts.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { continue; } let event = serde_json::from_str(read_receipt.json().get()) .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { let mut read = BTreeMap::new(); - let (event_id, mut receipt) = r .content .0 .into_iter() .next() .expect("we only use one event per read receipt"); + let receipt = receipt .remove(&ReceiptType::Read) .expect("our read receipts always set this") @@ -427,24 +437,17 @@ impl Service { async fn send_events_dest_appservice( &self, dest: &Destination, id: &str, events: Vec, ) -> SendingResult { - let mut pdu_jsons = Vec::new(); + let Some(appservice) = self.services.appservice.get_registration(id).await else { + return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration"))))); + }; + let mut pdu_jsons = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdu_jsons.push( - self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), - ) - })? - .to_room_event(), - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdu_jsons.push(pdu.to_room_event()); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Appservices don't need EDUs (?) and flush only; @@ -453,32 +456,24 @@ impl Service { } } + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::>(), + )); + //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); let client = &self.services.client.appservice; match appservice::send_request( client, - self.services - .appservice - .get_registration(id) - .await - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Could not load registration from db."), - ) - })?, + appservice, ruma::api::appservice::event::push_events::v1::Request { events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::>(), - ))) - .into(), + txn_id: txn_id.into(), ephemeral: Vec::new(), to_device: Vec::new(), }, @@ -494,23 +489,17 @@ impl Service { async fn send_events_dest_push( &self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec, ) -> SendingResult { - let mut pdus = Vec::new(); + let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else { + return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher"))))); + }; + let mut pdus = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdus.push( - self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Push] Event in servernameevent_data not found in db."), - ) - })?, - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdus.push(pdu); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Push gateways don't need EDUs (?) and flush only; @@ -529,28 +518,22 @@ impl Service { } } - let Some(pusher) = self - .services - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (dest.clone(), e))? - else { - continue; - }; - let rules_for_user = self .services .account_data .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::(event.get()).ok()) - .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + .await + .and_then(|event| serde_json::from_str::(event.get()).map_err(Into::into)) + .map_or_else( + |_| push::Ruleset::server_default(userid), + |ev: PushRulesEvent| ev.content.global, + ); let unread: UInt = self .services .user .notification_count(userid, &pdu.room_id) - .map_err(|e| (dest.clone(), e))? + .await .try_into() .expect("notification count can't go that high"); @@ -559,7 +542,6 @@ impl Service { .pusher .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) .await - .map(|_response| dest.clone()) .map_err(|e| (dest.clone(), e)); } @@ -586,21 +568,11 @@ impl Service { for event in &events { match event { // TODO: check room version and remove event_id if needed - SendingEvent::Pdu(pdu_id) => pdu_jsons.push( - self.convert_to_outgoing_federation_event( - self.services - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - ), - ), + SendingEvent::Pdu(pdu_id) => { + if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await { + pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await); + } + }, SendingEvent::Edu(edu) => { if let Ok(raw) = serde_json::from_slice(edu) { edu_jsons.push(raw); @@ -647,7 +619,7 @@ impl Service { } /// This does not return a full `Pdu` it is only to satisfy ruma's types. - pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { + pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { if let Some(unsigned) = pdu_json .get_mut("unsigned") .and_then(|val| val.as_object_mut()) @@ -660,7 +632,7 @@ impl Service { .get("room_id") .and_then(|val| RoomId::parse(val.as_str()?).ok()) { - match self.services.state.get_room_version(&room_id) { + match self.services.state.get_room_version(&room_id).await { Ok(room_version_id) => match room_version_id { RoomVersionId::V1 | RoomVersionId::V2 => {}, _ => _ = pdu_json.remove("event_id"), diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index a565e5009..ae2b8c3cb 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -5,7 +5,7 @@ use std::{ }; use conduit::{debug, debug_error, debug_warn, err, error, info, trace, warn, Err, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ discovery::{ @@ -179,7 +179,8 @@ impl Service { let result: BTreeMap<_, _> = self .services .globals - .verify_keys_for(origin)? + .verify_keys_for(origin) + .await? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -236,7 +237,8 @@ impl Service { .services .globals .db - .add_signing_key(&k.server_name, k.clone())? + .add_signing_key(&k.server_name, k.clone()) + .await .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>(); @@ -283,7 +285,8 @@ impl Service { .services .globals .db - .add_signing_key(&origin, key)? + .add_signing_key(&origin, key) + .await .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -384,7 +387,8 @@ impl Service { let mut result: BTreeMap<_, _> = self .services .globals - .verify_keys_for(origin)? + .verify_keys_for(origin) + .await? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -431,7 +435,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, k.clone())?; + .add_signing_key(origin, k.clone()) + .await; result.extend( k.verify_keys .into_iter() @@ -462,7 +467,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, server_key.clone())?; + .add_signing_key(origin, server_key.clone()) + .await; result.extend( server_key @@ -495,7 +501,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, server_key.clone())?; + .add_signing_key(origin, server_key.clone()) + .await; result.extend( server_key @@ -545,7 +552,8 @@ impl Service { self.services .globals .db - .add_signing_key(origin, k.clone())?; + .add_signing_key(origin, k.clone()) + .await; result.extend( k.verify_keys .into_iter() diff --git a/src/service/service.rs b/src/service/service.rs index 635f782ea..031650506 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -7,7 +7,7 @@ use std::{ }; use async_trait::async_trait; -use conduit::{err, error::inspect_log, utils::string::split_once_infallible, Err, Result, Server}; +use conduit::{err, error::inspect_log, utils::string::SplitInfallible, Err, Result, Server}; use database::Database; /// Abstract interface for a Service @@ -51,7 +51,7 @@ pub(crate) struct Args<'a> { /// Dep is a reference to a service used within another service. /// Circular-dependencies between services require this indirection. -pub(crate) struct Dep { +pub(crate) struct Dep { dep: OnceLock>, service: Weak, name: &'static str, @@ -62,7 +62,7 @@ pub(crate) type MapType = BTreeMap; pub(crate) type MapVal = (Weak, Weak); pub(crate) type MapKey = String; -impl Deref for Dep { +impl Deref for Dep { type Target = Arc; /// Dereference a dependency. The dependency must be ready or panics. @@ -80,7 +80,7 @@ impl Deref for Dep { impl<'a> Args<'a> { /// Create a lazy-reference to a service when constructing another Service. - pub(crate) fn depend(&'a self, name: &'static str) -> Dep { + pub(crate) fn depend(&'a self, name: &'static str) -> Dep { Dep:: { dep: OnceLock::new(), service: Arc::downgrade(self.service), @@ -90,17 +90,12 @@ impl<'a> Args<'a> { /// Create a reference immediately to a service when constructing another /// Service. The other service must be constructed. - pub(crate) fn require(&'a self, name: &'static str) -> Arc { - require::(self.service, name) - } + pub(crate) fn require(&'a self, name: &str) -> Arc { require::(self.service, name) } } /// Reference a Service by name. Panics if the Service does not exist or was /// incorrectly cast. -pub(crate) fn require<'a, 'b, T>(map: &'b Map, name: &'a str) -> Arc -where - T: Send + Sync + 'a + 'b + 'static, -{ +pub(crate) fn require(map: &Map, name: &str) -> Arc { try_get::(map, name) .inspect_err(inspect_log) .expect("Failure to reference service required by another service.") @@ -112,9 +107,9 @@ where /// # Panics /// Incorrect type is not a silent failure (None) as the type never has a reason /// to be incorrect. -pub(crate) fn get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Option> +pub(crate) fn get(map: &Map, name: &str) -> Option> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { map.read() .expect("locked for reading") @@ -129,9 +124,9 @@ where /// Reference a Service by name. Returns Err if the Service does not exist or /// was incorrectly cast. -pub(crate) fn try_get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Result> +pub(crate) fn try_get(map: &Map, name: &str) -> Result> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { map.read() .expect("locked for reading") @@ -152,4 +147,4 @@ where /// Utility for service implementations; see Service::name() in the trait. #[inline] -pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } +pub(crate) fn make_name(module_path: &str) -> &str { module_path.split_once_infallible("::").1 } diff --git a/src/service/services.rs b/src/service/services.rs index 8e69cdbb6..da22fb2d4 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -14,7 +14,7 @@ use crate::{ manager::Manager, media, presence, pusher, resolver, rooms, sending, server_keys, service, service::{Args, Map, Service}, - transaction_ids, uiaa, updates, users, + sync, transaction_ids, uiaa, updates, users, }; pub struct Services { @@ -32,6 +32,7 @@ pub struct Services { pub rooms: rooms::Service, pub sending: Arc, pub server_keys: Arc, + pub sync: Arc, pub transaction_ids: Arc, pub uiaa: Arc, pub updates: Arc, @@ -96,6 +97,7 @@ impl Services { }, sending: build!(sending::Service), server_keys: build!(server_keys::Service), + sync: build!(sync::Service), transaction_ids: build!(transaction_ids::Service), uiaa: build!(uiaa::Service), updates: build!(updates::Service), @@ -193,16 +195,16 @@ impl Services { } } - pub fn try_get<'a, 'b, T>(&'b self, name: &'a str) -> Result> + pub fn try_get(&self, name: &str) -> Result> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { service::try_get::(&self.service, name) } - pub fn get<'a, 'b, T>(&'b self, name: &'a str) -> Option> + pub fn get(&self, name: &str) -> Option> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { service::get::(&self.service, name) } diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs new file mode 100644 index 000000000..1bf4610ff --- /dev/null +++ b/src/service/sync/mod.rs @@ -0,0 +1,233 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::{Arc, Mutex, Mutex as StdMutex}, +}; + +use conduit::Result; +use ruma::{ + api::client::sync::sync_events::{ + self, + v4::{ExtensionsConfig, SyncRequestList}, + }, + OwnedDeviceId, OwnedRoomId, OwnedUserId, +}; + +pub struct Service { + connections: DbConnections, +} + +struct SlidingSyncCache { + lists: BTreeMap, + subscriptions: BTreeMap, + known_rooms: BTreeMap>, // For every room, the roomsince number + extensions: ExtensionsConfig, +} + +type DbConnections = Mutex>; +type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); +type DbConnectionsVal = Arc>; + +impl crate::Service for Service { + fn build(_args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + connections: StdMutex::new(BTreeMap::new()), + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { + self.connections + .lock() + .unwrap() + .contains_key(&(user_id, device_id, conn_id)) + } + + pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { + self.connections + .lock() + .expect("locked") + .remove(&(user_id, device_id, conn_id)); + } + + pub fn update_sync_request_with_cache( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, + ) -> BTreeMap> { + let Some(conn_id) = request.conn_id.clone() else { + return BTreeMap::new(); + }; + + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (list_id, list) in &mut request.lists { + if let Some(cached_list) = cached.lists.get(list_id) { + if list.sort.is_empty() { + list.sort.clone_from(&cached_list.sort); + }; + if list.room_details.required_state.is_empty() { + list.room_details + .required_state + .clone_from(&cached_list.room_details.required_state); + }; + list.room_details.timeline_limit = list + .room_details + .timeline_limit + .or(cached_list.room_details.timeline_limit); + list.include_old_rooms = list + .include_old_rooms + .clone() + .or_else(|| cached_list.include_old_rooms.clone()); + match (&mut list.filters, cached_list.filters.clone()) { + (Some(list_filters), Some(cached_filters)) => { + list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + if list_filters.spaces.is_empty() { + list_filters.spaces = cached_filters.spaces; + } + list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); + if list_filters.room_types.is_empty() { + list_filters.room_types = cached_filters.room_types; + } + if list_filters.not_room_types.is_empty() { + list_filters.not_room_types = cached_filters.not_room_types; + } + list_filters.room_name_like = list_filters + .room_name_like + .clone() + .or(cached_filters.room_name_like); + if list_filters.tags.is_empty() { + list_filters.tags = cached_filters.tags; + } + if list_filters.not_tags.is_empty() { + list_filters.not_tags = cached_filters.not_tags; + } + }, + (_, Some(cached_filters)) => list.filters = Some(cached_filters), + (Some(list_filters), _) => list.filters = Some(list_filters.clone()), + (..) => {}, + } + if list.bump_event_types.is_empty() { + list.bump_event_types + .clone_from(&cached_list.bump_event_types); + }; + } + cached.lists.insert(list_id.clone(), list.clone()); + } + + cached + .subscriptions + .extend(request.room_subscriptions.clone()); + request + .room_subscriptions + .extend(cached.subscriptions.clone()); + + request.extensions.e2ee.enabled = request + .extensions + .e2ee + .enabled + .or(cached.extensions.e2ee.enabled); + + request.extensions.to_device.enabled = request + .extensions + .to_device + .enabled + .or(cached.extensions.to_device.enabled); + + request.extensions.account_data.enabled = request + .extensions + .account_data + .enabled + .or(cached.extensions.account_data.enabled); + request.extensions.account_data.lists = request + .extensions + .account_data + .lists + .clone() + .or_else(|| cached.extensions.account_data.lists.clone()); + request.extensions.account_data.rooms = request + .extensions + .account_data + .rooms + .clone() + .or_else(|| cached.extensions.account_data.rooms.clone()); + + cached.extensions = request.extensions.clone(); + + cached.known_rooms.clone() + } + + pub fn update_sync_subscriptions( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, + subscriptions: BTreeMap, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + cached.subscriptions = subscriptions; + } + + pub fn update_sync_known_rooms( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, + new_cached_rooms: BTreeSet, globalsince: u64, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (roomid, lastsince) in cached + .known_rooms + .entry(list_id.clone()) + .or_default() + .iter_mut() + { + if !new_cached_rooms.contains(roomid) { + *lastsince = 0; + } + } + let list = cached.known_rooms.entry(list_id).or_default(); + for roomid in new_cached_rooms { + list.insert(roomid, globalsince); + } + } +} diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs deleted file mode 100644 index 791b46f01..000000000 --- a/src/service/transaction_ids/data.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, TransactionId, UserId}; - -pub struct Data { - userdevicetxnid_response: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - userdevicetxnid_response: db["userdevicetxnid_response"].clone(), - } - } - - pub(super) fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - self.userdevicetxnid_response.insert(&key, data)?; - - Ok(()) - } - - pub(super) fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - // If there's no entry, this is a new transaction - self.userdevicetxnid_response.get(&key) - } -} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 78e6337f2..72f60adb1 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,35 +1,45 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Handle, Map}; use ruma::{DeviceId, TransactionId, UserId}; pub struct Service { - pub db: Data, + db: Data, +} + +struct Data { + userdevicetxnid_response: Arc, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - self.db.add_txnid(user_id, device_id, txn_id, data) - } +#[implement(Service)] +pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); - pub fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - self.db.existing_txnid(user_id, device_id, txn_id) - } + self.db.userdevicetxnid_response.insert(&key, data); +} + +// If there's no entry, this is a new transaction +#[implement(Service)] +pub async fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, +) -> Result> { + let key = (user_id, device_id, txn_id); + self.db.userdevicetxnid_response.qry(&key).await } diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs deleted file mode 100644 index ce071da09..000000000 --- a/src/service/uiaa/data.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::{error::ErrorKind, uiaa::UiaaInfo}, - CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, -}; - -pub struct Data { - userdevicesessionid_uiaarequest: RwLock>, - userdevicesessionid_uiaainfo: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - userdevicesessionid_uiaainfo: db["userdevicesessionid_uiaainfo"].clone(), - } - } - - pub(super) fn set_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - - pub(super) fn get_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(ToOwned::to_owned) - } - - pub(super) fn update_uiaa_session( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } -} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 6041bbd34..0415bfc23 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,174 +1,243 @@ -mod data; - -use std::sync::Arc; +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; -use conduit::{error, utils, utils::hash, Error, Result, Server}; -use data::Data; +use conduit::{ + err, error, implement, utils, + utils::{hash, string::EMPTY}, + Error, Result, Server, +}; +use database::{Deserialized, Map}; use ruma::{ api::client::{ error::ErrorKind, uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, }, - CanonicalJsonValue, DeviceId, UserId, + CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, }; use crate::{globals, users, Dep}; -pub const SESSION_ID_LENGTH: usize = 32; - pub struct Service { - server: Arc, + userdevicesessionid_uiaarequest: RwLock, + db: Data, services: Services, - pub db: Data, } struct Services { + server: Arc, globals: Dep, users: Dep, } +struct Data { + userdevicesessionid_uiaainfo: Arc, +} + +type RequestMap = BTreeMap; +type RequestKey = (OwnedUserId, OwnedDeviceId, String); + +pub const SESSION_ID_LENGTH: usize = 32; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - server: args.server.clone(), + userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), + db: Data { + userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), + }, services: Services { + server: args.server.clone(), globals: args.depend::("globals"), users: args.depend::("users"), }, - db: Data::new(args.db), })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.db.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), /* TODO: better session error handling (why - * is it optional in ruma?) */ - json_body, - )?; - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } - - pub fn try_auth( - &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth.session().map_or_else( - || Ok(uiaainfo.clone()), - |session| self.db.get_uiaa_session(user_id, device_id, session), - )?; - - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } +/// Creates a new Uiaa session. Make sure the session token is unique. +#[implement(Service)] +pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) { + // TODO: better session error handling (why is uiaainfo.session optional in + // ruma?) + self.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + json_body, + ); + + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ); +} - match auth { - // Find out what the user completed - AuthData::Password(Password { - identifier, - password, - #[cfg(feature = "element_hacks")] - user, - .. - }) => { - #[cfg(feature = "element_hacks")] - let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { - username - } else if let Some(username) = user { - username - } else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - #[cfg(not(feature = "element_hacks"))] - let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier - else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; - - // Check if password is correct - if let Some(hash) = self.services.users.password_hash(&user_id)? { - let hash_matches = hash::verify_password(password, &hash).is_ok(); - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::forbidden(), - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } +#[implement(Service)] +pub async fn try_auth( + &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, +) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = if let Some(session) = auth.session() { + self.get_uiaa_session(user_id, device_id, session).await? + } else { + uiaainfo.clone() + }; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - }, - AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == self.server.config.registration_token.as_deref() { - uiaainfo.completed.push(AuthType::RegistrationToken); - } else { + match auth { + // Find out what the user completed + AuthData::Password(Password { + identifier, + password, + #[cfg(feature = "element_hacks")] + user, + .. + }) => { + #[cfg(feature = "element_hacks")] + let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { + username + } else if let Some(username) = user { + username + } else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; + + #[cfg(not(feature = "element_hacks"))] + let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier + else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; + + let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + + // Check if password is correct + if let Ok(hash) = self.services.users.password_hash(&user_id).await { + let hash_matches = hash::verify_password(password, &hash).is_ok(); + if !hash_matches { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { kind: ErrorKind::forbidden(), - message: "Invalid registration token.".to_owned(), + message: "Invalid username or password.".to_owned(), }); return Ok((false, uiaainfo)); } - }, - AuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - }, - k => error!("type not supported: {:?}", k), - } + } - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + }, + AuthData::RegistrationToken(t) => { + if Some(t.token.trim()) == self.services.server.config.registration_token.as_deref() { + uiaainfo.completed.push(AuthType::RegistrationToken); + } else { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::forbidden(), + message: "Invalid registration token.".to_owned(), + }); + return Ok((false, uiaainfo)); } - // We didn't break, so this flow succeeded! - completed = true; - } + }, + AuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + }, + k => error!("type not supported: {:?}", k), + } - if !completed { - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } } + // We didn't break, so this flow succeeded! + completed = true; + } - // UIAA was successful! Remove this session and return true - self.db.update_uiaa_session( + if !completed { + self.update_uiaa_session( user_id, device_id, uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) + Some(&uiaainfo), + ); + + return Ok((false, uiaainfo)); } - #[must_use] - pub fn get_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option { - self.db.get_uiaa_request(user_id, device_id, session) + // UIAA was successful! Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + ); + + Ok((true, uiaainfo)) +} + +#[implement(Service)] +fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) { + let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); + self.userdevicesessionid_uiaarequest + .write() + .expect("locked for writing") + .insert(key, request.to_owned()); +} + +#[implement(Service)] +pub fn get_uiaa_request( + &self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str, +) -> Option { + let key = ( + user_id.to_owned(), + device_id.unwrap_or_else(|| EMPTY.into()).to_owned(), + session.to_owned(), + ); + + self.userdevicesessionid_uiaarequest + .read() + .expect("locked for reading") + .get(&key) + .cloned() +} + +#[implement(Service)] +fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.db.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + ); + } else { + self.db + .userdevicesessionid_uiaainfo + .remove(&userdevicesessionid); } } + +#[implement(Service)] +async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { + let key = (user_id, device_id, session); + self.db + .userdevicesessionid_uiaainfo + .qry(&key) + .await + .deserialized() + .map_err(|_| err!(Request(Forbidden("UIAA session does not exist.")))) +} diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index 3c69b2430..4e16e22b0 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -1,19 +1,22 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{debug, err, info, utils, warn, Error, Result}; -use database::Map; +use conduit::{debug, info, warn, Result}; +use database::{Deserialized, Map}; use ruma::events::room::message::RoomMessageEventContent; use serde::Deserialize; -use tokio::{sync::Notify, time::interval}; +use tokio::{ + sync::Notify, + time::{interval, MissedTickBehavior}, +}; use crate::{admin, client, globals, Dep}; pub struct Service { - services: Services, - db: Arc, - interrupt: Notify, interval: Duration, + interrupt: Notify, + db: Arc, + services: Services, } struct Services { @@ -22,12 +25,12 @@ struct Services { globals: Dep, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] struct CheckForUpdatesResponse { updates: Vec, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] struct CheckForUpdatesResponseEntry { id: u64, date: String, @@ -42,33 +45,38 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), + interrupt: Notify::new(), + db: args.db["global"].clone(), services: Services { globals: args.depend::("globals"), admin: args.depend::("admin"), client: args.depend::("client"), }, - db: args.db["global"].clone(), - interrupt: Notify::new(), - interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), })) } + #[tracing::instrument(skip_all, name = "updates", level = "trace")] async fn worker(self: Arc) -> Result<()> { if !self.services.globals.allow_check_for_updates() { debug!("Disabling update check"); return Ok(()); } + let mut i = interval(self.interval); + i.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { tokio::select! { - () = self.interrupt.notified() => return Ok(()), + () = self.interrupt.notified() => break, _ = i.tick() => (), } - if let Err(e) = self.handle_updates().await { + if let Err(e) = self.check().await { warn!(%e, "Failed to check for updates"); } } + + Ok(()) } fn interrupt(&self) { self.interrupt.notify_waiters(); } @@ -77,52 +85,52 @@ impl crate::Service for Service { } impl Service { - #[tracing::instrument(skip_all)] - async fn handle_updates(&self) -> Result<()> { + #[tracing::instrument(skip_all, level = "trace")] + async fn check(&self) -> Result<()> { let response = self .services .client .default .get(CHECK_FOR_UPDATES_URL) .send() + .await? + .text() .await?; - let response = serde_json::from_str::(&response.text().await?) - .map_err(|e| err!("Bad check for updates response: {e}"))?; - - let mut last_update_id = self.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > self.last_check_for_updates_id()? { - info!("{:#}", update.message); - self.services - .admin - .send_message(RoomMessageEventContent::text_markdown(format!( - "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", - update.date, update.message - ))) - .await; + let response = serde_json::from_str::(&response)?; + for update in &response.updates { + if update.id > self.last_check_for_updates_id().await { + self.handle(update).await; + self.update_check_for_updates_id(update.id); } } - self.update_check_for_updates_id(last_update_id)?; Ok(()) } + async fn handle(&self, update: &CheckForUpdatesResponseEntry) { + info!("{} {:#}", update.date, update.message); + self.services + .admin + .send_message(RoomMessageEventContent::text_markdown(format!( + "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", + update.date, update.message + ))) + .await + .ok(); + } + #[inline] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + pub fn update_check_for_updates_id(&self, id: u64) { self.db - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes()); } - pub fn last_check_for_updates_id(&self) -> Result { + pub async fn last_check_for_updates_id(&self) -> u64 { self.db - .get(LAST_CHECK_FOR_UPDATES_COUNT)? - .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) + .qry(LAST_CHECK_FOR_UPDATES_COUNT) + .await + .deserialized() + .unwrap_or(0_u64) } } diff --git a/src/service/users/data.rs b/src/service/users/data.rs deleted file mode 100644 index 70ff12e3f..000000000 --- a/src/service/users/data.rs +++ /dev/null @@ -1,1098 +0,0 @@ -use std::{collections::BTreeMap, mem::size_of, sync::Arc}; - -use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server}; -use database::Map; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, - OwnedMxcUri, OwnedUserId, UInt, UserId, -}; - -use crate::{globals, rooms, users::clean_signatures, Dep}; - -pub struct Data { - keychangeid_userid: Arc, - keyid_key: Arc, - onetimekeyid_onetimekeys: Arc, - openidtoken_expiresatuserid: Arc, - todeviceid_events: Arc, - token_userdeviceid: Arc, - userdeviceid_metadata: Arc, - userdeviceid_token: Arc, - userfilterid_filter: Arc, - userid_avatarurl: Arc, - userid_blurhash: Arc, - userid_devicelistversion: Arc, - userid_displayname: Arc, - userid_lastonetimekeyupdate: Arc, - userid_masterkeyid: Arc, - userid_password: Arc, - userid_selfsigningkeyid: Arc, - userid_usersigningkeyid: Arc, - useridprofilekey_value: Arc, - services: Services, -} - -struct Services { - server: Arc, - globals: Dep, - state_cache: Dep, - state_accessor: Dep, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - keychangeid_userid: db["keychangeid_userid"].clone(), - keyid_key: db["keyid_key"].clone(), - onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), - openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - token_userdeviceid: db["token_userdeviceid"].clone(), - userdeviceid_metadata: db["userdeviceid_metadata"].clone(), - userdeviceid_token: db["userdeviceid_token"].clone(), - userfilterid_filter: db["userfilterid_filter"].clone(), - userid_avatarurl: db["userid_avatarurl"].clone(), - userid_blurhash: db["userid_blurhash"].clone(), - userid_devicelistversion: db["userid_devicelistversion"].clone(), - userid_displayname: db["userid_displayname"].clone(), - userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - userid_masterkeyid: db["userid_masterkeyid"].clone(), - userid_password: db["userid_password"].clone(), - userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(), - userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(), - useridprofilekey_value: db["useridprofilekey_value"].clone(), - services: Services { - server: args.server.clone(), - globals: args.depend::("globals"), - state_cache: args.depend::("rooms::state_cache"), - state_accessor: args.depend::("rooms::state_accessor"), - }, - } - } - - /// Check if a user has an account on this homeserver. - #[inline] - pub(super) fn exists(&self, user_id: &UserId) -> Result { - Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) - } - - /// Check if account is deactivated - pub(super) fn is_deactivated(&self, user_id: &UserId) -> Result { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? - .is_empty()) - } - - /// Returns the number of users registered on this server. - #[inline] - pub(super) fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } - - /// Find out which user an access token belongs to. - pub(super) fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xFF); - let user_bytes = parts - .next() - .ok_or_else(|| err!(Database("User ID in token_userdeviceid is invalid.")))?; - let device_bytes = parts - .next() - .ok_or_else(|| err!(Database("Device ID in token_userdeviceid is invalid.")))?; - - Ok(Some(( - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid. {e}")))?, - utils::string_from_bytes(device_bytes) - .map_err(|e| err!(Database("Device ID in token_userdeviceid is invalid. {e}")))?, - ))) - }) - } - - /// Returns an iterator over all users on this homeserver. - pub fn iter<'a>(&'a self) -> Box> + 'a> { - Box::new(self.userid_password.iter().map(|(bytes, _)| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("User ID in userid_password is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in userid_password is invalid. {e}"))) - })) - } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. - pub(super) fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } - - /// Returns the password hash for the given user. - pub(super) fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } - - /// Hash and set the user's password to the Argon2 hash - pub(super) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::hash::password(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } - - /// Returns the displayname of a user on this homeserver. - pub(super) fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("Displayname in db is invalid. {e}")))?, - )) - }) - } - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. - pub(super) fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the `avatar_url` of a user. - pub(super) fn avatar_url(&self, user_id: &UserId) -> Result> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database(warn!("Avatar URL in db is invalid: {e}"))))?; - let mxc_uri: OwnedMxcUri = s_bytes.into(); - Ok(mxc_uri) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the blurhash of a user. - pub(super) fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Avatar URL in db is invalid. {e}"))) - }) - .transpose() - } - - /// Gets a specific user profile key - pub(super) fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - self.useridprofilekey_value - .get(&key)? - .map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes).unwrap()))) - } - - /// Gets all the user's profile keys and values in an iterator - pub(super) fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a + Send> { - let prefix = user_id.as_bytes().to_vec(); - - Box::new( - self.useridprofilekey_value - .scan_prefix(prefix) - .map(|(key, value)| { - let profile_key_name = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("Profile key in db is invalid")))?, - ) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - let profile_key_value = serde_json::from_slice(&value) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - Ok((profile_key_name, profile_key_value)) - }), - ) - } - - /// Sets a new profile key value, removes the key if value is None - pub(super) fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(value) = profile_key_value { - let value = serde_json::to_vec(&value).unwrap(); - - self.useridprofilekey_value.insert(&key, &value) - } else { - self.useridprofilekey_value.remove(&key) - } - } - - /// Get the timezone of a user. - pub(super) fn timezone(&self, user_id: &UserId) -> Result> { - // first check the unstable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - let value = self - .useridprofilekey_value - .get(&key)? - .map(|bytes| utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}")))) - .transpose() - .unwrap(); - - // TODO: transparently migrate unstable key usage to the stable key once MSC4133 - // and MSC4175 are stable, likely a remove/insert in this block - if value.is_none() || value.as_ref().is_some_and(String::is_empty) { - // check the stable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"m.tz"); - - return self - .useridprofilekey_value - .get(&key)? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}"))) - }) - .transpose(); - } - - Ok(value) - } - - /// Sets a new timezone or removes it if timezone is None. - pub(super) fn set_timezone(&self, user_id: &UserId, timezone: Option) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(timezone) = timezone { - self.useridprofilekey_value - .insert(&key, timezone.as_bytes())?; - } else { - self.useridprofilekey_value.remove(&key)?; - } - - Ok(()) - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Adds a new device to a user. - pub(super) fn create_device( - &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, - client_ip: Option, - ) -> Result<()> { - // This method should never be called for nonexistent users. We shouldn't assert - // though... - if !self.exists(user_id)? { - warn!("Called create_device for non-existent user {} in database", user_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); - } - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: client_ip, - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } - - /// Removes a device from a user. - pub(super) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xFF); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } - - /// Returns an iterator over all device ids of this user. - pub(super) fn all_device_ids<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("UserDevice ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("Device ID in userdeviceid_metadata is invalid. {e}")))? - .into()) - }), - ) - } - - /// Replaces the access token of one device. - pub(super) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // should not be None, but we shouldn't assert either lol... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - pub(super) fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&key)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - key.push(0xFF); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - pub(super) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Count in roomid_lastroomactiveupdate is invalid. {e}"))) - }) - } - - pub(super) fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?, - ) - .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?, - serde_json::from_slice(&value).map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?, - )) - }) - .transpose() - } - - pub(super) fn count_one_time_keys( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in self - .onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKey ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("DeviceKeyId in db is invalid. {e}")))? - .algorithm(), - ) - }) { - let count: &mut UInt = counts.entry(algorithm?).or_default(); - *count = count.saturating_add(uint!(1)); - } - - Ok(counts) - } - - pub(super) fn add_device_keys( - &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id)?; - - Ok(()) - } - - pub(super) fn add_cross_signing_keys( - &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, - user_signing_key: &Option>, notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let (master_key_key, _) = Self::parse_master_key(user_id, master_key)?; - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? - .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - if notify { - self.mark_device_key_update(user_id)?; - } - - Ok(()) - } - - pub(super) fn sign_key( - &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = serde_json::from_slice( - &self - .keyid_key - .get(&key)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, - ) - .map_err(|e| err!(Database("key in keyid_key is invalid. {e}")))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? - .as_object_mut() - .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? - .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - self.mark_device_key_update(target_id)?; - - Ok(()) - } - - pub(super) fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a> { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - Box::new( - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?, - ) - .map_err(|e| err!(Database("User ID in devicekeychangeid_userid is invalid. {e}"))) - }), - ) - } - - pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = self.services.globals.next_count()?.to_be_bytes(); - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - // Don't send key updates to unencrypted rooms - if self - .services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - pub(super) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes).map_err(|e| err!(Database("DeviceKeys in db are invalid. {e}")))?, - )) - }) - } - - pub(super) fn parse_master_key( - user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) - } - - pub(super) fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?; - clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), - ))) - }) - } - - pub(super) fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?, - )) - }) - }) - } - - pub(super) fn add_to_device_event( - &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - pub(super) fn get_to_device_events( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|e| err!(Database("Event in todeviceid_events is invalid. {e}")))?, - ); - } - - Ok(events) - } - - pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::())..key.len()]) - .map_err(|e| err!(Database("ToDeviceId has invalid count bytes. {e}")))?, - )) - }) - .filter_map(Result::ok) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - pub(super) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ - metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } - - /// Get device metadata. - pub(super) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } - - pub(super) fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Invalid devicelistversion in db. {e}"))) - .map(Some) - }) - } - - pub(super) fn all_devices_metadata<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|e| err!(Database("Device in userdeviceid_metadata is invalid. {e}"))) - }), - ) - } - - /// Creates a new sync filter. Returns the filter id. - pub(super) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter - .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; - - Ok(filter_id) - } - - pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw).map_err(|e| err!(Database("Invalid filter event in db. {e}"))) - } else { - Ok(None) - } - } - - /// Creates an OpenID token, which can be used to prove that a user has - /// access to an account (primarily for integrations) - pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - use std::num::Saturating as Sat; - - let expires_in = self.services.server.config.openid_token_ttl; - let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); - - let mut value = expires_at.0.to_be_bytes().to_vec(); - value.extend_from_slice(user_id.as_bytes()); - - self.openidtoken_expiresatuserid - .insert(token.as_bytes(), value.as_slice())?; - - Ok(expires_in) - } - - /// Find out which user an OpenID access token belongs to. - pub(super) fn find_from_openid_token(&self, token: &str) -> Result { - let Some(value) = self.openidtoken_expiresatuserid.get(token.as_bytes())? else { - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is unrecognised")); - }; - - let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); - - let expires_at = u64::from_be_bytes( - expires_at_bytes - .try_into() - .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, - ); - - if expires_at < utils::millis_since_unix_epoch() { - debug_info!("OpenID token is expired, removing"); - self.openidtoken_expiresatuserid.remove(token.as_bytes())?; - - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is expired")); - } - - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}"))) - } -} - -/// Will only return with Some(username) if the password was not empty and the -/// username could be successfully parsed. -/// If `utils::string_from_bytes`(...) returns an error that username will be -/// skipped and the error will be logged. -pub(super) fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); - None - }, - } - } -} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 80897b5ff..438c220bc 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,552 +1,986 @@ -mod data; +use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc}; -use std::{ - collections::{BTreeMap, BTreeSet}, - mem, - sync::{Arc, Mutex, Mutex as StdMutex}, +use conduit::{ + debug_warn, err, utils, + utils::{stream::TryIgnore, string::Unquoted, ReadyExt, TryReadyExt}, + warn, Err, Error, Result, Server, }; - -use conduit::{Error, Result}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - device::Device, - filter::FilterDefinition, - sync::sync_events::{ - self, - v4::{ExtensionsConfig, SyncRequestList}, - }, - }, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, + events::{AnyToDeviceEvent, StateEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedRoomId, OwnedUserId, - UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use self::data::Data; -use crate::{admin, rooms, Dep}; +use crate::{admin, globals, rooms, Dep}; pub struct Service { - connections: DbConnections, - pub db: Data, services: Services, + db: Data, } struct Services { + server: Arc, admin: Dep, + globals: Dep, + state_accessor: Dep, state_cache: Dep, } +struct Data { + keychangeid_userid: Arc, + keyid_key: Arc, + onetimekeyid_onetimekeys: Arc, + openidtoken_expiresatuserid: Arc, + todeviceid_events: Arc, + token_userdeviceid: Arc, + userdeviceid_metadata: Arc, + userdeviceid_token: Arc, + userfilterid_filter: Arc, + userid_avatarurl: Arc, + userid_blurhash: Arc, + userid_devicelistversion: Arc, + userid_displayname: Arc, + userid_lastonetimekeyupdate: Arc, + userid_masterkeyid: Arc, + userid_password: Arc, + userid_selfsigningkeyid: Arc, + userid_usersigningkeyid: Arc, + useridprofilekey_value: Arc, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - connections: StdMutex::new(BTreeMap::new()), - db: Data::new(&args), services: Services { + server: args.server.clone(), admin: args.depend::("admin"), + globals: args.depend::("globals"), + state_accessor: args.depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), }, + db: Data { + keychangeid_userid: args.db["keychangeid_userid"].clone(), + keyid_key: args.db["keyid_key"].clone(), + onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(), + openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(), + todeviceid_events: args.db["todeviceid_events"].clone(), + token_userdeviceid: args.db["token_userdeviceid"].clone(), + userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(), + userdeviceid_token: args.db["userdeviceid_token"].clone(), + userfilterid_filter: args.db["userfilterid_filter"].clone(), + userid_avatarurl: args.db["userid_avatarurl"].clone(), + userid_blurhash: args.db["userid_blurhash"].clone(), + userid_devicelistversion: args.db["userid_devicelistversion"].clone(), + userid_displayname: args.db["userid_displayname"].clone(), + userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), + userid_masterkeyid: args.db["userid_masterkeyid"].clone(), + userid_password: args.db["userid_password"].clone(), + userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(), + userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(), + useridprofilekey_value: args.db["useridprofilekey_value"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -type DbConnections = Mutex>; -type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); -type DbConnectionsVal = Arc>; - -struct SlidingSyncCache { - lists: BTreeMap, - subscriptions: BTreeMap, - known_rooms: BTreeMap>, // For every room, the roomsince number - extensions: ExtensionsConfig, -} - impl Service { - /// Check if a user has an account on this homeserver. + /// Check if a user is an admin #[inline] - pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } - - pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { - self.connections - .lock() - .unwrap() - .contains_key(&(user_id, device_id, conn_id)) - } + pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await } - pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { - self.connections - .lock() - .unwrap() - .remove(&(user_id, device_id, conn_id)); + /// Create a new user account on this homeserver. + #[inline] + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.set_password(user_id, password) } - pub fn update_sync_request_with_cache( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, - ) -> BTreeMap> { - let Some(conn_id) = request.conn_id.clone() else { - return BTreeMap::new(); - }; - - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (list_id, list) in &mut request.lists { - if let Some(cached_list) = cached.lists.get(list_id) { - if list.sort.is_empty() { - list.sort.clone_from(&cached_list.sort); - }; - if list.room_details.required_state.is_empty() { - list.room_details - .required_state - .clone_from(&cached_list.room_details.required_state); - }; - list.room_details.timeline_limit = list - .room_details - .timeline_limit - .or(cached_list.room_details.timeline_limit); - list.include_old_rooms = list - .include_old_rooms - .clone() - .or_else(|| cached_list.include_old_rooms.clone()); - match (&mut list.filters, cached_list.filters.clone()) { - (Some(list_filters), Some(cached_filters)) => { - list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); - if list_filters.spaces.is_empty() { - list_filters.spaces = cached_filters.spaces; - } - list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); - list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); - if list_filters.room_types.is_empty() { - list_filters.room_types = cached_filters.room_types; - } - if list_filters.not_room_types.is_empty() { - list_filters.not_room_types = cached_filters.not_room_types; - } - list_filters.room_name_like = list_filters - .room_name_like - .clone() - .or(cached_filters.room_name_like); - if list_filters.tags.is_empty() { - list_filters.tags = cached_filters.tags; - } - if list_filters.not_tags.is_empty() { - list_filters.not_tags = cached_filters.not_tags; - } - }, - (_, Some(cached_filters)) => list.filters = Some(cached_filters), - (Some(list_filters), _) => list.filters = Some(list_filters.clone()), - (..) => {}, - } - if list.bump_event_types.is_empty() { - list.bump_event_types - .clone_from(&cached_list.bump_event_types); - }; - } - cached.lists.insert(list_id.clone(), list.clone()); - } + /// Deactivate account + pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + self.all_device_ids(user_id) + .for_each(|device_id| self.remove_device(user_id, device_id)) + .await; - cached - .subscriptions - .extend(request.room_subscriptions.clone()); - request - .room_subscriptions - .extend(cached.subscriptions.clone()); - - request.extensions.e2ee.enabled = request - .extensions - .e2ee - .enabled - .or(cached.extensions.e2ee.enabled); - - request.extensions.to_device.enabled = request - .extensions - .to_device - .enabled - .or(cached.extensions.to_device.enabled); - - request.extensions.account_data.enabled = request - .extensions - .account_data - .enabled - .or(cached.extensions.account_data.enabled); - request.extensions.account_data.lists = request - .extensions - .account_data - .lists - .clone() - .or_else(|| cached.extensions.account_data.lists.clone()); - request.extensions.account_data.rooms = request - .extensions - .account_data - .rooms - .clone() - .or_else(|| cached.extensions.account_data.rooms.clone()); - - cached.extensions = request.extensions.clone(); - - cached.known_rooms.clone() - } - - pub fn update_sync_subscriptions( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, - subscriptions: BTreeMap, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); + // Set the password to "" to indicate a deactivated account. Hashes will never + // result in an empty string, so the user will not be able to log in again. + // Systems like changing the password without logging in should check if the + // account is deactivated. + self.set_password(user_id, None)?; - cached.subscriptions = subscriptions; + // TODO: Unhook 3PID + Ok(()) } - pub fn update_sync_known_rooms( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, - new_cached_rooms: BTreeSet, globalsince: u64, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (roomid, lastsince) in cached - .known_rooms - .entry(list_id.clone()) - .or_default() - .iter_mut() - { - if !new_cached_rooms.contains(roomid) { - *lastsince = 0; - } - } - let list = cached.known_rooms.entry(list_id).or_default(); - for roomid in new_cached_rooms { - list.insert(roomid, globalsince); - } - } + /// Check if a user has an account on this homeserver. + #[inline] + pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.get(user_id).await.is_ok() } /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result { self.db.is_deactivated(user_id) } - - /// Check if a user is an admin - pub fn is_admin(&self, user_id: &UserId) -> Result { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { - self.services.state_cache.is_joined(user_id, &admin_room_id) - } else { - Ok(false) - } + pub async fn is_deactivated(&self, user_id: &UserId) -> Result { + self.db + .userid_password + .get(user_id) + .map_ok(|val| val.is_empty()) + .map_err(|_| err!(Request(NotFound("User does not exist.")))) + .await } - /// Create a new user account on this homeserver. - #[inline] - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password)?; - Ok(()) + /// Check if account is active, infallible + pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) } + + /// Check if account is active, infallible + pub async fn is_active_local(&self, user_id: &UserId) -> bool { + self.services.globals.user_is_local(user_id) && self.is_active(user_id).await } /// Returns the number of users registered on this server. #[inline] - pub fn count(&self) -> Result { self.db.count() } + pub async fn count(&self) -> usize { self.db.userid_password.count().await } /// Find out which user an access token belongs to. - pub fn find_from_token(&self, token: &str) -> Result> { - self.db.find_from_token(token) + pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> { + self.db.token_userdeviceid.get(token).await.deserialized() } + /// Returns an iterator over all users on this homeserver (offered for + /// compatibility) + #[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)] + pub fn iter(&self) -> impl Stream + Send + '_ { self.stream().map(ToOwned::to_owned) } + /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator> + '_ { self.db.iter() } + pub fn stream(&self) -> impl Stream + Send { self.db.userid_password.keys().ignore_err() } /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is /// greater then zero. - pub fn list_local_users(&self) -> Result> { self.db.list_local_users() } + pub fn list_local_users(&self) -> impl Stream + Send + '_ { + self.db + .userid_password + .stream() + .ignore_err() + .ready_filter_map(|(u, p): (&UserId, &[u8])| (!p.is_empty()).then_some(u)) + } /// Returns the password hash for the given user. - pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } + pub async fn password_hash(&self, user_id: &UserId) -> Result { + self.db.userid_password.get(user_id).await.deserialized() + } /// Hash and set the user's password to the Argon2 hash - #[inline] pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password) + if let Some(password) = password { + if let Ok(hash) = utils::hash::password(password) { + self.db + .userid_password + .insert(user_id.as_bytes(), hash.as_bytes()); + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Password does not meet the requirements.", + )) + } + } else { + self.db.userid_password.insert(user_id.as_bytes(), b""); + Ok(()) + } } /// Returns the displayname of a user on this homeserver. - pub fn displayname(&self, user_id: &UserId) -> Result> { self.db.displayname(user_id) } + pub async fn displayname(&self, user_id: &UserId) -> Result { + self.db.userid_displayname.get(user_id).await.deserialized() + } /// Sets a new displayname or removes it if displayname is None. You still /// need to nofify all rooms of this change. - pub async fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - self.db.set_displayname(user_id, displayname) + pub fn set_displayname(&self, user_id: &UserId, displayname: Option) { + if let Some(displayname) = displayname { + self.db + .userid_displayname + .insert(user_id.as_bytes(), displayname.as_bytes()); + } else { + self.db.userid_displayname.remove(user_id.as_bytes()); + } } - /// Get the avatar_url of a user. - pub fn avatar_url(&self, user_id: &UserId) -> Result> { self.db.avatar_url(user_id) } + /// Get the `avatar_url` of a user. + pub async fn avatar_url(&self, user_id: &UserId) -> Result { + self.db.userid_avatarurl.get(user_id).await.deserialized() + } /// Sets a new avatar_url or removes it if avatar_url is None. - pub async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - self.db.set_avatar_url(user_id, avatar_url) + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) { + if let Some(avatar_url) = avatar_url { + self.db + .userid_avatarurl + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes()); + } else { + self.db.userid_avatarurl.remove(user_id.as_bytes()); + } } /// Get the blurhash of a user. - pub fn blurhash(&self, user_id: &UserId) -> Result> { self.db.blurhash(user_id) } - - pub fn timezone(&self, user_id: &UserId) -> Result> { self.db.timezone(user_id) } - - /// Gets a specific user profile key - pub fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result> { - self.db.profile_key(user_id, profile_key) + pub async fn blurhash(&self, user_id: &UserId) -> Result { + self.db.userid_blurhash.get(user_id).await.deserialized() } - /// Gets all the user's profile keys and values in an iterator - pub fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box> + 'a + Send> { - self.db.all_profile_keys(user_id) - } - - /// Sets a new profile key value, removes the key if value is None - pub fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, - ) -> Result<()> { - self.db - .set_profile_key(user_id, profile_key, profile_key_value) - } - - /// Sets a new tz or removes it if tz is None. - pub async fn set_timezone(&self, user_id: &UserId, tz: Option) -> Result<()> { - self.db.set_timezone(user_id, tz) - } - - /// Sets a new blurhash or removes it if blurhash is None. - pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - self.db.set_blurhash(user_id, blurhash) + /// Sets a new avatar_url or removes it if avatar_url is None. + pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) { + if let Some(blurhash) = blurhash { + self.db + .userid_blurhash + .insert(user_id.as_bytes(), blurhash.as_bytes()); + } else { + self.db.userid_blurhash.remove(user_id.as_bytes()); + } } /// Adds a new device to a user. - pub fn create_device( + pub async fn create_device( &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, client_ip: Option, ) -> Result<()> { - self.db - .create_device(user_id, device_id, token, initial_device_display_name, client_ip) + // This method should never be called for nonexistent users. We shouldn't assert + // though... + if !self.exists(user_id).await { + warn!("Called create_device for non-existent user {} in database", user_id); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); + } + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + self.db.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(&Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: client_ip, + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }) + .expect("Device::to_string never fails."), + ); + + self.set_token(user_id, device_id, token).await?; + + Ok(()) } /// Removes a device from a user. - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.remove_device(user_id, device_id) + pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Remove tokens + if let Ok(old_token) = self.db.userdeviceid_token.get(&userdeviceid).await { + self.db.userdeviceid_token.remove(&userdeviceid); + self.db.token_userdeviceid.remove(&old_token); + } + + // Remove todevice events + let prefix = (user_id, device_id, Interfix); + self.db + .todeviceid_events + .keys_raw_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.todeviceid_events.remove(key)) + .await; + + // TODO: Remove onetimekeys + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + self.db.userdeviceid_metadata.remove(&userdeviceid); } /// Returns an iterator over all device ids of this user. - pub fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { - self.db.all_device_ids(user_id) + pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.db + .userdeviceid_metadata + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, device_id): (Ignore, &DeviceId)| device_id) } /// Replaces the access token of one device. - #[inline] - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - self.db.set_token(user_id, device_id, token) + pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + let key = (user_id, device_id); + // should not be None, but we shouldn't assert either lol... + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + ?user_id, + ?device_id, + "User does not exist or device has no metadata." + ))); + } + + // Remove old token + if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { + self.db.token_userdeviceid.remove(&old_token); + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to user device combination + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + self.db + .userdeviceid_token + .insert(&userdeviceid, token.as_bytes()); + self.db + .token_userdeviceid + .insert(token.as_bytes(), &userdeviceid); + + Ok(()) } - pub fn add_one_time_key( + pub async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, ) -> Result<()> { + // All devices have metadata + // Only existing devices should be able to call this, but we shouldn't assert + // either... + let key = (user_id, device_id); + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + ?user_id, + ?device_id, + "User does not exist or device has no metadata." + ))); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + key.push(0xFF); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), + ); + + self.db.onetimekeyid_onetimekeys.insert( + &key, + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + ); + self.db - .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) - } + .userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes()); - // TODO: use this ? - #[allow(dead_code)] - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.db.last_one_time_keys_update(user_id) + Ok(()) } - pub fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - self.db.take_one_time_key(user_id, device_id, key_algorithm) + pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 { + self.db + .userid_lastonetimekeyupdate + .get(user_id) + .await + .deserialized() + .unwrap_or(0) } - pub fn count_one_time_keys( + pub async fn take_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, + ) -> Result<(OwnedDeviceKeyId, Raw)> { + self.db + .userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes()); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + let one_time_key = self + .db + .onetimekeyid_onetimekeys + .raw_stream_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + self.db.onetimekeyid_onetimekeys.remove(key); + + let key = key + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid."))) + .unwrap(); + + let key = serde_json::from_slice(key) + .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}"))) + .unwrap(); + + let val = serde_json::from_slice(val) + .map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}"))) + .unwrap(); + + (key, val) + }) + .next() + .await; + + one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found")))) + } + + pub async fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - self.db.count_one_time_keys(user_id, device_id) - } + ) -> BTreeMap { + type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore); + + let mut algorithm_counts = BTreeMap::::new(); + let query = (user_id, device_id); + self.db + .onetimekeyid_onetimekeys + .stream_prefix(&query) + .ignore_err() + .ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| { + let device_key_id: &DeviceKeyId = device_key_id + .as_str() + .try_into() + .expect("Invalid DeviceKeyID in database"); + + let count: &mut UInt = algorithm_counts + .entry(device_key_id.algorithm()) + .or_default(); + + *count = count.saturating_add(1_u32.into()); + }) + .await; + + algorithm_counts + } + + pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.db.keyid_key.insert( + &userdeviceid, + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + ); - pub fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { - self.db.add_device_keys(user_id, device_id, device_keys) + self.mark_device_key_update(user_id).await; } - pub fn add_cross_signing_keys( + pub async fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, notify: bool, ) -> Result<()> { + // TODO: Check signatures + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let (master_key_key, _) = parse_master_key(user_id, master_key)?; + + self.db + .keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes()); + self.db - .add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify) + .userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key); + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))? + .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.db + .keyid_key + .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes()); + + self.db + .userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key); + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))? + .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids + .next() + .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?; + + if user_signing_key_ids.next().is_some() { + return Err!(Request(InvalidParam("User signing key contained more than one key."))); + } + + let mut user_signing_key_key = prefix; + user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + + self.db + .keyid_key + .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes()); + + self.db + .userid_usersigningkeyid + .insert(user_id.as_bytes(), &user_signing_key_key); + } + + if notify { + self.mark_device_key_update(user_id).await; + } + + Ok(()) } - pub fn sign_key( + pub async fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, ) -> Result<()> { - self.db.sign_key(target_id, key_id, signature, sender_id) + let key = (target_id, key_id); + + let mut cross_signing_key: serde_json::Value = self + .db + .keyid_key + .qry(&key) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))? + .deserialized() + .map_err(|e| err!(Database("key in keyid_key is invalid. {e:?}")))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? + .as_object_mut() + .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? + .entry(sender_id.to_string()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? + .insert(signature.0, signature.1.into()); + + let mut key = target_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(key_id.as_bytes()); + self.db.keyid_key.insert( + &key, + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + ); + + self.mark_device_key_update(target_id).await; + + Ok(()) } pub fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> impl Iterator> + 'a { - self.db.keys_changed(user_or_room_id, from, to) - } + &'a self, user_or_room_id: &'a str, from: u64, to: Option, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = ((&'a str, u64), &'a UserId); - #[inline] - pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) } + let to = to.unwrap_or(u64::MAX); + let start = (user_or_room_id, from.saturating_add(1)); + self.db + .keychangeid_userid + .stream_from(&start) + .ignore_err() + .ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to) + .map(|((..), user_id): KeyVal<'_>| user_id) + } + + pub async fn mark_device_key_update(&self, user_id: &UserId) { + let count = self.services.globals.next_count().unwrap().to_be_bytes(); + + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + + pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + // Don't send key updates to unencrypted rooms + if self + .services + .state_accessor + .room_state_get(room_id, &StateEventType::RoomEncryption, "") + .await + .is_err() + { + continue; + } - pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - self.db.get_device_keys(user_id, device_id) - } + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); - #[inline] - pub fn parse_master_key( - &self, user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - Data::parse_master_key(user_id, master_key) + self.db.keychangeid_userid.insert(&key, user_id.as_bytes()); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); + self.db.keychangeid_userid.insert(&key, user_id.as_bytes()); } - #[inline] - pub fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_key(key, sender_user, user_id, allowed_signatures) + pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result> { + let key_id = (user_id, device_id); + self.db.keyid_key.qry(&key_id).await.deserialized() } - pub fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_master_key(sender_user, user_id, allowed_signatures) + pub async fn get_key( + &self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, + ) -> Result> + where + F: Fn(&UserId) -> bool + Send + Sync, + { + let key = self + .db + .keyid_key + .get(key_id) + .await + .deserialized::()?; + + let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; + let raw_value = serde_json::value::to_raw_value(&cleaned)?; + Ok(Raw::from_json(raw_value)) + } + + pub async fn get_master_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, + ) -> Result> + where + F: Fn(&UserId) -> bool + Send + Sync, + { + let key_id = self.db.userid_masterkeyid.get(user_id).await?; + + self.get_key(&key_id, sender_user, user_id, allowed_signatures) + .await } - pub fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_self_signing_key(sender_user, user_id, allowed_signatures) + pub async fn get_self_signing_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, + ) -> Result> + where + F: Fn(&UserId) -> bool + Send + Sync, + { + let key_id = self.db.userid_selfsigningkeyid.get(user_id).await?; + + self.get_key(&key_id, sender_user, user_id, allowed_signatures) + .await } - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.db.get_user_signing_key(user_id) + pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { + let key_id = self.db.userid_usersigningkeyid.get(user_id).await?; + + self.db.keyid_key.get(&*key_id).await.deserialized() } - pub fn add_to_device_event( + pub async fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, - ) -> Result<()> { - self.db - .add_to_device_event(sender, target_user_id, target_device_id, event_type, content) - } + ) { + let mut key = target_user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(target_device_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes()); + + let mut json = serde_json::Map::new(); + json.insert("type".to_owned(), event_type.to_owned().into()); + json.insert("sender".to_owned(), sender.to_string().into()); + json.insert("content".to_owned(), content); + + let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - self.db.get_to_device_events(user_id, device_id) + self.db.todeviceid_events.insert(&key, &value); } - pub fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - self.db.remove_to_device_events(user_id, device_id, until) + pub fn get_to_device_events<'a>( + &'a self, user_id: &'a UserId, device_id: &'a DeviceId, + ) -> impl Stream> + Send + 'a { + let prefix = (user_id, device_id, Interfix); + self.db + .todeviceid_events + .stream_raw_prefix(&prefix) + .ready_and_then(|(_, val)| serde_json::from_slice(val).map_err(Into::into)) + .ignore_err() } - pub fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - self.db.update_device_metadata(user_id, device_id, device) + pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + + let mut last = prefix.clone(); + last.extend_from_slice(&until.to_be_bytes()); + + self.db + .todeviceid_events + .rev_raw_keys_from(&last) // this includes last + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) + .map(|key| { + let len = key.len(); + let start = len.saturating_sub(size_of::()); + let count = utils::u64_from_u8(&key[start..len]); + (key, count) + }) + .ready_take_while(move |(_, count)| *count <= until) + .ready_for_each(|(key, _)| self.db.todeviceid_events.remove(&key)) + .boxed() + .await; + } + + pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { + // Only existing devices should be able to call this, but we shouldn't assert + // either... + let key = (user_id, device_id); + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + ?user_id, + ?device_id, + "Called update_device_metadata for a non-existent user and/or device" + ))); + } + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + self.db.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), + ); + + Ok(()) } /// Get device metadata. - pub fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { - self.db.get_device_metadata(user_id, device_id) + pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result { + self.db + .userdeviceid_metadata + .qry(&(user_id, device_id)) + .await + .deserialized() } - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.db.get_devicelist_version(user_id) + pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { + self.db + .userid_devicelistversion + .get(user_id) + .await + .deserialized() } - pub fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { - self.db.all_devices_metadata(user_id) + pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + self.db + .userdeviceid_metadata + .stream_raw_prefix(&(user_id, Interfix)) + .ready_and_then(|(_, val)| serde_json::from_slice::(val).map_err(Into::into)) + .ignore_err() } - /// Deactivate account - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } + /// Creates a new sync filter. Returns the filter id. + pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String { + let filter_id = utils::random_string(4); - // Set the password to "" to indicate a deactivated account. Hashes will never - // result in an empty string, so the user will not be able to log in again. - // Systems like changing the password without logging in should check if the - // account is deactivated. - self.db.set_password(user_id, None)?; + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(filter_id.as_bytes()); - // TODO: Unhook 3PID - Ok(()) - } + self.db + .userfilterid_filter + .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json")); - /// Creates a new sync filter. Returns the filter id. - pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - self.db.create_filter(user_id, filter) + filter_id } - pub fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - self.db.get_filter(user_id, filter_id) + pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result { + self.db + .userfilterid_filter + .qry(&(user_id, filter_id)) + .await + .deserialized() } /// Creates an OpenID token, which can be used to prove that a user has /// access to an account (primarily for integrations) pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - self.db.create_openid_token(user_id, token) + use std::num::Saturating as Sat; + + let expires_in = self.services.server.config.openid_token_ttl; + let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); + + let mut value = expires_at.0.to_be_bytes().to_vec(); + value.extend_from_slice(user_id.as_bytes()); + + self.db + .openidtoken_expiresatuserid + .insert(token.as_bytes(), value.as_slice()); + + Ok(expires_in) } /// Find out which user an OpenID access token belongs to. - pub fn find_from_openid_token(&self, token: &str) -> Result { self.db.find_from_openid_token(token) } + pub async fn find_from_openid_token(&self, token: &str) -> Result { + let Ok(value) = self.db.openidtoken_expiresatuserid.get(token).await else { + return Err!(Request(Unauthorized("OpenID token is unrecognised"))); + }; + + let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); + let expires_at = u64::from_be_bytes( + expires_at_bytes + .try_into() + .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, + ); + + if expires_at < utils::millis_since_unix_epoch() { + debug_warn!("OpenID token is expired, removing"); + self.db.openidtoken_expiresatuserid.remove(token.as_bytes()); + + return Err!(Request(Unauthorized("OpenID token is expired"))); + } + + let user_string = utils::string_from_bytes(user_bytes) + .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?; + + UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}"))) + } + + /// Gets a specific user profile key + pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result { + let key = (user_id, profile_key); + self.db + .useridprofilekey_value + .qry(&key) + .await + .deserialized() + } + + /// Gets all the user's profile keys and values in an iterator + pub fn all_profile_keys<'a>( + &'a self, user_id: &'a UserId, + ) -> impl Stream + 'a + Send { + type KeyVal = ((Ignore, String), serde_json::Value); + + let prefix = (user_id, Interfix); + self.db + .useridprofilekey_value + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, key), val): KeyVal| (key, val)) + } + + /// Sets a new profile key value, removes the key if value is None + pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(profile_key.as_bytes()); + + // TODO: insert to the stable MSC4175 key when it's stable + if let Some(value) = profile_key_value { + let value = serde_json::to_vec(&value).unwrap(); + + self.db.useridprofilekey_value.insert(&key, &value); + } else { + self.db.useridprofilekey_value.remove(&key); + } + } + + /// Get the timezone of a user. + pub async fn timezone(&self, user_id: &UserId) -> Result { + // TODO: transparently migrate unstable key usage to the stable key once MSC4133 + // and MSC4175 are stable, likely a remove/insert in this block. + + // first check the unstable prefix then check the stable prefix + let unstable_key = (user_id, "us.cloke.msc4175.tz"); + let stable_key = (user_id, "m.tz"); + self.db + .useridprofilekey_value + .qry(&unstable_key) + .or_else(|_| self.db.useridprofilekey_value.qry(&stable_key)) + .await + .deserialized() + } + + /// Sets a new timezone or removes it if timezone is None. + pub fn set_timezone(&self, user_id: &UserId, timezone: Option) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(b"us.cloke.msc4175.tz"); + + // TODO: insert to the stable MSC4175 key when it's stable + if let Some(timezone) = timezone { + self.db + .useridprofilekey_value + .insert(&key, timezone.as_bytes()); + } else { + self.db.useridprofilekey_value.remove(&key); + } + } +} + +pub fn parse_master_key(user_id: &UserId, master_key: &Raw) -> Result<(Vec, CrossSigningKey)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let master_key = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let mut master_key_ids = master_key.keys.values(); + let master_key_id = master_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + Ok((master_key_key, master_key)) } /// Ensure that a user only sees signatures from themselves and the target user -pub fn clean_signatures bool>( - cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F, -) -> Result<(), Error> { +fn clean_signatures( + mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, +) -> Result +where + F: Fn(&UserId) -> bool + Send + Sync, +{ if let Some(signatures) = cross_signing_key .get_mut("signatures") .and_then(|v| v.as_object_mut()) @@ -563,5 +997,12 @@ pub fn clean_signatures bool>( } } - Ok(()) + Ok(cross_signing_key) +} + +//TODO: this is an ABA +fn increment(db: &Arc, key: &[u8]) { + let old = db.get_blocking(key); + let new = utils::increment(old.ok().as_deref()); + db.insert(key, &new); }