Skip to content

Commit

Permalink
Relaxing back the Send trait on return types for wasm
Browse files Browse the repository at this point in the history
  • Loading branch information
oestradiol committed Sep 19, 2024
1 parent 078241f commit 4d45c41
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 72 deletions.
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,6 @@ mockito = "1.4"
# WebAssembly
wasm-bindgen-test = "0.3.41"
bumpalo = "~3.14.0"

# Code generation
trait-variant = "0.1.2"
1 change: 1 addition & 0 deletions atrium-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ serde_bytes.workspace = true
serde_json.workspace = true
thiserror.workspace = true
tokio = { workspace = true, optional = true }
trait-variant.workspace = true

[features]
default = ["agent", "bluesky"]
Expand Down
7 changes: 4 additions & 3 deletions atrium-api/src/agent/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use std::future::Future;
pub use self::memory::MemorySessionStore;
pub(crate) use super::Session;

#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait SessionStore {
#[must_use]
fn get_session(&self) -> impl Future<Output = Option<Session>> + Send;
fn get_session(&self) -> impl Future<Output = Option<Session>>;
#[must_use]
fn set_session(&self, session: Session) -> impl Future<Output = ()> + Send;
fn set_session(&self, session: Session) -> impl Future<Output = ()>;
#[must_use]
fn clear_session(&self) -> impl Future<Output = ()> + Send;
fn clear_session(&self) -> impl Future<Output = ()>;
}
1 change: 1 addition & 0 deletions atrium-xrpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ serde = { workspace = true, features = ["derive"] }
serde_html_form.workspace = true
serde_json.workspace = true
thiserror.workspace = true
trait-variant.workspace = true

[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt"] }
Expand Down
162 changes: 100 additions & 62 deletions atrium-xrpc/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ use std::fmt::Debug;
use std::future::Future;

/// An abstract HTTP client.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait HttpClient {
/// Send an HTTP request and return the response.
fn send_http(
&self,
request: Request<Vec<u8>>,
) -> impl Future<Output = core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>> + Send;
) -> impl Future<
Output = core::result::Result<
Response<Vec<u8>>,
Box<dyn std::error::Error + Send + Sync + 'static>,
>,
>;
}

type XrpcResult<O, E> = core::result::Result<OutputDataOrBytes<O>, self::Error<E>>;
Expand All @@ -22,88 +28,120 @@ type XrpcResult<O, E> = core::result::Result<OutputDataOrBytes<O>, self::Error<E
///
/// [`send_xrpc()`](XrpcClient::send_xrpc) method has a default implementation,
/// which wraps the [`HttpClient::send_http()`]` method to handle input and output as an XRPC Request.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait XrpcClient: HttpClient {
/// The base URI of the XRPC server.
fn base_uri(&self) -> String;
/// Get the authentication token to use `Authorization` header.
#[allow(unused_variables)]
fn authentication_token(&self, is_refresh: bool) -> impl Future<Output = Option<String>> + Send {
fn authentication_token(&self, is_refresh: bool) -> impl Future<Output = Option<String>> {
async { None }
}
/// Get the `atproto-proxy` header.
fn atproto_proxy_header(&self) -> impl Future<Output = Option<String>> + Send {
fn atproto_proxy_header(&self) -> impl Future<Output = Option<String>> {
async { None }
}
/// Get the `atproto-accept-labelers` header.
fn atproto_accept_labelers_header(&self) -> impl Future<Output = Option<Vec<String>>> + Send {
fn atproto_accept_labelers_header(&self) -> impl Future<Output = Option<Vec<String>>> {
async { None }
}
/// Send an XRPC request and return the response.
fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> impl Future<Output = XrpcResult<O, E>> + Send
#[cfg(not(target_arch = "wasm32"))]
fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> impl Future<Output = XrpcResult<O, E>>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
// This code is duplicated because of this trait bound.
// `Self` has to be `Sync` for `Future` to be `Send`.
Self: Sync,
{
async {
let mut uri = format!("{}/xrpc/{}", self.base_uri(), request.nsid);
// Query parameters
if let Some(p) = &request.parameters {
serde_html_form::to_string(p).map(|qs| {
uri += "?";
uri += &qs;
})?;
};
let mut builder = Request::builder().method(&request.method).uri(&uri);
// Headers
if let Some(encoding) = &request.encoding {
builder = builder.header(Header::ContentType, encoding);
}
if let Some(token) = self
.authentication_token(
request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION,
)
.await
{
builder = builder.header(Header::Authorization, format!("Bearer {}", token));
}
if let Some(proxy) = self.atproto_proxy_header().await {
builder = builder.header(Header::AtprotoProxy, proxy);
}
if let Some(accept_labelers) = self.atproto_accept_labelers_header().await {
builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", "));
}
// Body
let body = if let Some(input) = &request.input {
match input {
InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?,
InputDataOrBytes::Bytes(bytes) => bytes.clone(),
}
} else {
Vec::new()
};
// Send
let (parts, body) =
self.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts();
if parts.status.is_success() {
if parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map_or(false, |content_type| content_type.starts_with("application/json"))
{
Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
} else {
Ok(OutputDataOrBytes::Bytes(body))
}
} else {
Err(Error::XrpcResponse(XrpcError {
status: parts.status,
error: serde_json::from_slice::<XrpcErrorKind<E>>(&body).ok(),
}))
}
send_xrpc(self, request)
}
#[cfg(target_arch = "wasm32")]
fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> impl Future<Output = XrpcResult<O, E>>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
send_xrpc(self, request)
}
}

#[inline(always)]
async fn send_xrpc<P, I, O, E, C: XrpcClient + ?Sized>(
client: &C,
request: &XrpcRequest<P, I>,
) -> XrpcResult<O, E>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
let mut uri = format!("{}/xrpc/{}", client.base_uri(), request.nsid);
// Query parameters
if let Some(p) = &request.parameters {
serde_html_form::to_string(p).map(|qs| {
uri += "?";
uri += &qs;
})?;
};
let mut builder = Request::builder().method(&request.method).uri(&uri);
// Headers
if let Some(encoding) = &request.encoding {
builder = builder.header(Header::ContentType, encoding);
}
if let Some(token) = client
.authentication_token(
request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION,
)
.await
{
builder = builder.header(Header::Authorization, format!("Bearer {}", token));
}
if let Some(proxy) = client.atproto_proxy_header().await {
builder = builder.header(Header::AtprotoProxy, proxy);
}
if let Some(accept_labelers) = client.atproto_accept_labelers_header().await {
builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", "));
}
// Body
let body = if let Some(input) = &request.input {
match input {
InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?,
InputDataOrBytes::Bytes(bytes) => bytes.clone(),
}
} else {
Vec::new()
};
// Send
let (parts, body) =
client.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts();
if parts.status.is_success() {
if parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map_or(false, |content_type| content_type.starts_with("application/json"))
{
Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
} else {
Ok(OutputDataOrBytes::Bytes(body))
}
} else {
Err(Error::XrpcResponse(XrpcError {
status: parts.status,
error: serde_json::from_slice::<XrpcErrorKind<E>>(&body).ok(),
}))
}
}
1 change: 1 addition & 0 deletions bsky-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ serde_json.workspace = true
thiserror.workspace = true
toml = { version = "0.8.13", optional = true }
unicode-segmentation = { version = "1.11.0", optional = true }
trait-variant.workspace = true

[dev-dependencies]
ipld-core.workspace = true
Expand Down
8 changes: 6 additions & 2 deletions bsky-sdk/src/agent/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ pub trait Loader {
/// Loads the configuration data.
fn load(
&self,
) -> impl Future<Output = core::result::Result<Config, Box<dyn std::error::Error + Send + Sync + 'static>>> + Send;
) -> impl Future<
Output = core::result::Result<Config, Box<dyn std::error::Error + Send + Sync + 'static>>,
> + Send;
}

/// The trait for saving configuration data.
Expand All @@ -60,5 +62,7 @@ pub trait Saver {
fn save(
&self,
config: &Config,
) -> impl Future<Output = core::result::Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>> + Send;
) -> impl Future<
Output = core::result::Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>,
> + Send;
}
22 changes: 17 additions & 5 deletions bsky-sdk/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use atrium_api::com::atproto::repo::{
use atrium_api::types::{Collection, LimitedNonZeroU8, TryIntoUnknown};
use atrium_api::xrpc::XrpcClient;

#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait Record<T, S>
where
T: XrpcClient + Send + Sync,
Expand All @@ -21,11 +22,22 @@ where
agent: &BskyAgent<T, S>,
cursor: Option<String>,
limit: Option<LimitedNonZeroU8<100u8>>,
) -> impl Future<Output = Result<list_records::Output>> + Send;
fn get(agent: &BskyAgent<T, S>, rkey: String) -> impl Future<Output = Result<get_record::Output>> + Send;
fn put(self, agent: &BskyAgent<T, S>, rkey: String) -> impl Future<Output = Result<put_record::Output>> + Send;
fn create(self, agent: &BskyAgent<T, S>) -> impl Future<Output = Result<create_record::Output>> + Send;
fn delete(agent: &BskyAgent<T, S>, rkey: String) -> impl Future<Output = Result<delete_record::Output>> + Send;
) -> impl Future<Output = Result<list_records::Output>>;
fn get(
agent: &BskyAgent<T, S>,
rkey: String,
) -> impl Future<Output = Result<get_record::Output>>;
fn put(
self,
agent: &BskyAgent<T, S>,
rkey: String,
) -> impl Future<Output = Result<put_record::Output>>;
fn create(self, agent: &BskyAgent<T, S>)
-> impl Future<Output = Result<create_record::Output>>;
fn delete(
agent: &BskyAgent<T, S>,
rkey: String,
) -> impl Future<Output = Result<delete_record::Output>>;
}

macro_rules! record_impl {
Expand Down

0 comments on commit 4d45c41

Please sign in to comment.