diff --git a/Cargo.lock b/Cargo.lock index 2abd3976..4c23ae65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "trait-variant", "wasm-bindgen-test", ] @@ -165,6 +166,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "trait-variant", "wasm-bindgen-test", ] @@ -279,6 +281,7 @@ dependencies = [ "thiserror", "tokio", "toml", + "trait-variant", "unicode-segmentation", ] @@ -2340,6 +2343,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "try-lock" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index 6a848b77..3e989e1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,3 +72,6 @@ mockito = "1.4" # WebAssembly wasm-bindgen-test = "0.3.41" bumpalo = "~3.14.0" + +# Code generation +trait-variant = "0.1.2" \ No newline at end of file diff --git a/atrium-api/Cargo.toml b/atrium-api/Cargo.toml index f82c8ce9..10d9cd91 100644 --- a/atrium-api/Cargo.toml +++ b/atrium-api/Cargo.toml @@ -23,6 +23,7 @@ serde_bytes.workspace = true serde_json.workspace = true thiserror.workspace = true tokio = { workspace = true, optional = true } +trait-variant.workspace = true [features] default = ["agent", "bluesky"] diff --git a/atrium-api/src/agent/store.rs b/atrium-api/src/agent/store.rs index b79a5aef..22bdcb37 100644 --- a/atrium-api/src/agent/store.rs +++ b/atrium-api/src/agent/store.rs @@ -5,11 +5,12 @@ use std::future::Future; pub use self::memory::MemorySessionStore; pub(crate) use super::Session; +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait SessionStore { #[must_use] - fn get_session(&self) -> impl Future> + Send; + fn get_session(&self) -> impl Future>; #[must_use] - fn set_session(&self, session: Session) -> impl Future + Send; + fn set_session(&self, session: Session) -> impl Future; #[must_use] - fn clear_session(&self) -> impl Future + Send; + fn clear_session(&self) -> impl Future; } diff --git a/atrium-xrpc/Cargo.toml b/atrium-xrpc/Cargo.toml index 1daba4d9..e8c551d3 100644 --- a/atrium-xrpc/Cargo.toml +++ b/atrium-xrpc/Cargo.toml @@ -17,6 +17,7 @@ serde = { workspace = true, features = ["derive"] } serde_html_form.workspace = true serde_json.workspace = true thiserror.workspace = true +trait-variant.workspace = true [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt"] } diff --git a/atrium-xrpc/src/traits.rs b/atrium-xrpc/src/traits.rs index 4726aa58..f04e3176 100644 --- a/atrium-xrpc/src/traits.rs +++ b/atrium-xrpc/src/traits.rs @@ -8,12 +8,18 @@ use std::fmt::Debug; use std::future::Future; /// An abstract HTTP client. +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait HttpClient { /// Send an HTTP request and return the response. fn send_http( &self, request: Request>, - ) -> impl Future>, Box>> + Send; + ) -> impl Future< + Output = core::result::Result< + Response>, + Box, + >, + >; } type XrpcResult = core::result::Result, self::Error>; @@ -22,88 +28,120 @@ type XrpcResult = core::result::Result, self::Error String; /// Get the authentication token to use `Authorization` header. #[allow(unused_variables)] - fn authentication_token(&self, is_refresh: bool) -> impl Future> + Send { + fn authentication_token(&self, is_refresh: bool) -> impl Future> { async { None } } /// Get the `atproto-proxy` header. - fn atproto_proxy_header(&self) -> impl Future> + Send { + fn atproto_proxy_header(&self) -> impl Future> { async { None } } /// Get the `atproto-accept-labelers` header. - fn atproto_accept_labelers_header(&self) -> impl Future>> + Send { + fn atproto_accept_labelers_header(&self) -> impl Future>> { async { None } } /// Send an XRPC request and return the response. - fn send_xrpc(&self, request: &XrpcRequest) -> impl Future> + Send + #[cfg(not(target_arch = "wasm32"))] + fn send_xrpc( + &self, + request: &XrpcRequest, + ) -> impl Future> where P: Serialize + Send + Sync, I: Serialize + Send + Sync, O: DeserializeOwned + Send + Sync, E: DeserializeOwned + Send + Sync + Debug, + // This code is duplicated because of this trait bound. + // `Self` has to be `Sync` for `Future` to be `Send`. Self: Sync, { - async { - let mut uri = format!("{}/xrpc/{}", self.base_uri(), request.nsid); - // Query parameters - if let Some(p) = &request.parameters { - serde_html_form::to_string(p).map(|qs| { - uri += "?"; - uri += &qs; - })?; - }; - let mut builder = Request::builder().method(&request.method).uri(&uri); - // Headers - if let Some(encoding) = &request.encoding { - builder = builder.header(Header::ContentType, encoding); - } - if let Some(token) = self - .authentication_token( - request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION, - ) - .await - { - builder = builder.header(Header::Authorization, format!("Bearer {}", token)); - } - if let Some(proxy) = self.atproto_proxy_header().await { - builder = builder.header(Header::AtprotoProxy, proxy); - } - if let Some(accept_labelers) = self.atproto_accept_labelers_header().await { - builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); - } - // Body - let body = if let Some(input) = &request.input { - match input { - InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?, - InputDataOrBytes::Bytes(bytes) => bytes.clone(), - } - } else { - Vec::new() - }; - // Send - let (parts, body) = - self.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts(); - if parts.status.is_success() { - if parts - .headers - .get(http::header::CONTENT_TYPE) - .and_then(|value| value.to_str().ok()) - .map_or(false, |content_type| content_type.starts_with("application/json")) - { - Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?)) - } else { - Ok(OutputDataOrBytes::Bytes(body)) - } - } else { - Err(Error::XrpcResponse(XrpcError { - status: parts.status, - error: serde_json::from_slice::>(&body).ok(), - })) - } + send_xrpc(self, request) + } + #[cfg(target_arch = "wasm32")] + fn send_xrpc( + &self, + request: &XrpcRequest, + ) -> impl Future> + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + send_xrpc(self, request) + } +} + +#[inline(always)] +async fn send_xrpc( + client: &C, + request: &XrpcRequest, +) -> XrpcResult +where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, +{ + let mut uri = format!("{}/xrpc/{}", client.base_uri(), request.nsid); + // Query parameters + if let Some(p) = &request.parameters { + serde_html_form::to_string(p).map(|qs| { + uri += "?"; + uri += &qs; + })?; + }; + let mut builder = Request::builder().method(&request.method).uri(&uri); + // Headers + if let Some(encoding) = &request.encoding { + builder = builder.header(Header::ContentType, encoding); + } + if let Some(token) = client + .authentication_token( + request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION, + ) + .await + { + builder = builder.header(Header::Authorization, format!("Bearer {}", token)); + } + if let Some(proxy) = client.atproto_proxy_header().await { + builder = builder.header(Header::AtprotoProxy, proxy); + } + if let Some(accept_labelers) = client.atproto_accept_labelers_header().await { + builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", ")); + } + // Body + let body = if let Some(input) = &request.input { + match input { + InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?, + InputDataOrBytes::Bytes(bytes) => bytes.clone(), + } + } else { + Vec::new() + }; + // Send + let (parts, body) = + client.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts(); + if parts.status.is_success() { + if parts + .headers + .get(http::header::CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .map_or(false, |content_type| content_type.starts_with("application/json")) + { + Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?)) + } else { + Ok(OutputDataOrBytes::Bytes(body)) } + } else { + Err(Error::XrpcResponse(XrpcError { + status: parts.status, + error: serde_json::from_slice::>(&body).ok(), + })) } } diff --git a/bsky-sdk/Cargo.toml b/bsky-sdk/Cargo.toml index e0184a11..1761eaac 100644 --- a/bsky-sdk/Cargo.toml +++ b/bsky-sdk/Cargo.toml @@ -23,6 +23,7 @@ serde_json.workspace = true thiserror.workspace = true toml = { version = "0.8.13", optional = true } unicode-segmentation = { version = "1.11.0", optional = true } +trait-variant.workspace = true [dev-dependencies] ipld-core.workspace = true diff --git a/bsky-sdk/src/agent/config.rs b/bsky-sdk/src/agent/config.rs index 85454ac3..a804e729 100644 --- a/bsky-sdk/src/agent/config.rs +++ b/bsky-sdk/src/agent/config.rs @@ -51,7 +51,9 @@ pub trait Loader { /// Loads the configuration data. fn load( &self, - ) -> impl Future>> + Send; + ) -> impl Future< + Output = core::result::Result>, + > + Send; } /// The trait for saving configuration data. @@ -60,5 +62,7 @@ pub trait Saver { fn save( &self, config: &Config, - ) -> impl Future>> + Send; + ) -> impl Future< + Output = core::result::Result<(), Box>, + > + Send; } diff --git a/bsky-sdk/src/record.rs b/bsky-sdk/src/record.rs index 3688b83b..81776e6a 100644 --- a/bsky-sdk/src/record.rs +++ b/bsky-sdk/src/record.rs @@ -12,6 +12,7 @@ use atrium_api::com::atproto::repo::{ use atrium_api::types::{Collection, LimitedNonZeroU8, TryIntoUnknown}; use atrium_api::xrpc::XrpcClient; +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait Record where T: XrpcClient + Send + Sync, @@ -21,11 +22,22 @@ where agent: &BskyAgent, cursor: Option, limit: Option>, - ) -> impl Future> + Send; - fn get(agent: &BskyAgent, rkey: String) -> impl Future> + Send; - fn put(self, agent: &BskyAgent, rkey: String) -> impl Future> + Send; - fn create(self, agent: &BskyAgent) -> impl Future> + Send; - fn delete(agent: &BskyAgent, rkey: String) -> impl Future> + Send; + ) -> impl Future>; + fn get( + agent: &BskyAgent, + rkey: String, + ) -> impl Future>; + fn put( + self, + agent: &BskyAgent, + rkey: String, + ) -> impl Future>; + fn create(self, agent: &BskyAgent) + -> impl Future>; + fn delete( + agent: &BskyAgent, + rkey: String, + ) -> impl Future>; } macro_rules! record_impl {