From 21bae4e8cdf26b129572fb56be01d277d95620bb Mon Sep 17 00:00:00 2001 From: Ivan Reshetnikov Date: Mon, 12 Jun 2023 13:39:50 +0200 Subject: [PATCH] feat: watch rpc types; http client (#22) --- Cargo.toml | 15 +- examples/http_client.rs | 72 ++++ .../{basic_client.rs => websocket_client.rs} | 19 +- relay_client/Cargo.toml | 16 +- relay_client/src/{errors.rs => error.rs} | 40 +- relay_client/src/http.rs | 292 +++++++++++++++ relay_client/src/lib.rs | 147 +++++--- relay_client/src/{client.rs => websocket.rs} | 52 ++- .../src/{client => websocket}/connection.rs | 27 +- .../src/{client => websocket}/fetch.rs | 3 +- .../src/{client => websocket}/inbound.rs | 0 .../src/{client => websocket}/outbound.rs | 65 +--- .../src/{client => websocket}/stream.rs | 37 +- relay_rpc/src/auth.rs | 232 +----------- relay_rpc/src/auth/did.rs | 6 +- relay_rpc/src/auth/tests.rs | 131 ------- relay_rpc/src/domain.rs | 119 +++++- relay_rpc/src/domain/tests.rs | 27 -- relay_rpc/src/jwt.rs | 348 ++++++++++++++++++ relay_rpc/src/lib.rs | 2 + relay_rpc/src/rpc.rs | 175 ++++++++- relay_rpc/src/rpc/tests.rs | 127 ++++++- relay_rpc/src/rpc/watch.rs | 237 ++++++++++++ relay_rpc/src/serde_helpers.rs | 56 +++ 24 files changed, 1667 insertions(+), 578 deletions(-) create mode 100644 examples/http_client.rs rename examples/{basic_client.rs => websocket_client.rs} (85%) rename relay_client/src/{errors.rs => error.rs} (58%) create mode 100644 relay_client/src/http.rs rename relay_client/src/{client.rs => websocket.rs} (83%) rename relay_client/src/{client => websocket}/connection.rs (84%) rename relay_client/src/{client => websocket}/fetch.rs (98%) rename relay_client/src/{client => websocket}/inbound.rs (100%) rename relay_client/src/{client => websocket}/outbound.rs (63%) rename relay_client/src/{client => websocket}/stream.rs (90%) delete mode 100644 relay_rpc/src/auth/tests.rs delete mode 100644 relay_rpc/src/domain/tests.rs create mode 100644 relay_rpc/src/jwt.rs create mode 100644 relay_rpc/src/rpc/watch.rs create mode 100644 relay_rpc/src/serde_helpers.rs diff --git a/Cargo.toml b/Cargo.toml index f577079..f63f273 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,8 @@ members = [ ] [features] -full = [ - "client", - "rpc", -] +default = ["full"] +full = ["client", "rpc"] client = ["dep:relay_client"] rpc = ["dep:relay_rpc"] @@ -28,7 +26,12 @@ relay_rpc = { path = "./relay_rpc", optional = true } anyhow = "1" structopt = { version = "0.3", default-features = false } tokio = { version = "1.22", features = ["full"] } +url = "2.3" + +[[example]] +name = "websocket_client" +required-features = ["client","rpc"] [[example]] -name = "basic_client" -required-features = ["full"] +name = "http_client" +required-features = ["client","rpc"] diff --git a/examples/http_client.rs b/examples/http_client.rs new file mode 100644 index 0000000..02b1ff5 --- /dev/null +++ b/examples/http_client.rs @@ -0,0 +1,72 @@ +use { + relay_client::{http::Client, ConnectionOptions}, + relay_rpc::{ + auth::{ed25519_dalek::Keypair, rand, AuthToken}, + domain::Topic, + }, + std::{sync::Arc, time::Duration}, + structopt::StructOpt, + url::Url, +}; + +#[derive(StructOpt)] +struct Args { + /// Specify HTTP address. + #[structopt(short, long, default_value = "https://relay.walletconnect.com/rpc")] + address: String, + + /// Specify WalletConnect project ID. + #[structopt(short, long, default_value = "3cbaa32f8fbf3cdcc87d27ca1fa68069")] + project_id: String, +} + +fn create_conn_opts(key: &Keypair, address: &str, project_id: &str) -> ConnectionOptions { + let aud = Url::parse(address) + .unwrap() + .origin() + .unicode_serialization(); + + let auth = AuthToken::new("http://example.com") + .aud(aud) + .ttl(Duration::from_secs(60 * 60)) + .as_jwt(key) + .unwrap(); + + ConnectionOptions::new(project_id, auth).with_address(address) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::from_args(); + + let key1 = Keypair::generate(&mut rand::thread_rng()); + let client1 = Client::new(&create_conn_opts(&key1, &args.address, &args.project_id))?; + + let key2 = Keypair::generate(&mut rand::thread_rng()); + let client2 = Client::new(&create_conn_opts(&key2, &args.address, &args.project_id))?; + + let topic = Topic::generate(); + let message: Arc = Arc::from("Hello WalletConnect!"); + + client1 + .publish( + topic.clone(), + message.clone(), + 1100, + Duration::from_secs(30), + ) + .await?; + + println!("[client1] published message with topic: {topic}",); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let messages = client2.fetch(topic).await?.messages; + let message = messages + .get(0) + .ok_or(anyhow::anyhow!("fetch did not return any messages"))?; + + println!("[client2] received message: {}", message.message); + + Ok(()) +} diff --git a/examples/basic_client.rs b/examples/websocket_client.rs similarity index 85% rename from examples/basic_client.rs rename to examples/websocket_client.rs index e31283a..8eb1513 100644 --- a/examples/basic_client.rs +++ b/examples/websocket_client.rs @@ -1,15 +1,12 @@ use { relay_client::{ - Client, - CloseFrame, - ConnectionHandler, + error::Error, + websocket::{Client, CloseFrame, ConnectionHandler, PublishedMessage}, ConnectionOptions, - Error, - PublishedMessage, }, relay_rpc::{ auth::{ed25519_dalek::Keypair, rand, AuthToken}, - domain::{AuthSubject, Topic}, + domain::Topic, }, std::{sync::Arc, time::Duration}, structopt::StructOpt, @@ -64,7 +61,7 @@ impl ConnectionHandler for Handler { fn create_conn_opts(address: &str, project_id: &str) -> ConnectionOptions { let key = Keypair::generate(&mut rand::thread_rng()); - let auth = AuthToken::new(AuthSubject::generate()) + let auth = AuthToken::new("http://example.com") .aud(address) .ttl(Duration::from_secs(60 * 60)) .as_jwt(&key) @@ -79,12 +76,12 @@ async fn main() -> anyhow::Result<()> { let client1 = Client::new(Handler::new("client1")); client1 - .connect(create_conn_opts(&args.address, &args.project_id)) + .connect(&create_conn_opts(&args.address, &args.project_id)) .await?; let client2 = Client::new(Handler::new("client2")); client2 - .connect(create_conn_opts(&args.address, &args.project_id)) + .connect(&create_conn_opts(&args.address, &args.project_id)) .await?; let topic = Topic::generate(); @@ -101,7 +98,9 @@ async fn main() -> anyhow::Result<()> { ) .await?; - println!("[client2] published message with topic: {topic}"); + println!("[client2] published message with topic: {topic}",); + + tokio::time::sleep(Duration::from_millis(500)).await; drop(client1); drop(client2); diff --git a/relay_client/Cargo.toml b/relay_client/Cargo.toml index 78acc7f..cdf5e13 100644 --- a/relay_client/Cargo.toml +++ b/relay_client/Cargo.toml @@ -11,14 +11,20 @@ rustls = ["tokio-tungstenite/rustls-tls-native-roots"] relay_rpc = { path = "../relay_rpc" } futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } thiserror = "1.0" -tokio = { version = "1.22", features = ["rt", "time", "sync", "macros", "rt-multi-thread"] } -tokio-tungstenite = "0.18" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_qs = "0.10" -futures-channel = "0.3" -tokio-stream = "0.1" -tokio-util = "0.7" pin-project = "1.0" chrono = { version = "0.4", default-features = false, features = ["alloc", "std"] } url = "2.3" +http = "0.2" + +# HTTP client dependencies. +reqwest = { version = "0.11", features = ["json"] } + +# WebSocket client dependencies. +tokio = { version = "1.22", features = ["rt", "time", "sync", "macros", "rt-multi-thread"] } +tokio-tungstenite = "0.18" +futures-channel = "0.3" +tokio-stream = "0.1" +tokio-util = "0.7" diff --git a/relay_client/src/errors.rs b/relay_client/src/error.rs similarity index 58% rename from relay_client/src/errors.rs rename to relay_client/src/error.rs index 3d7b895..a76e984 100644 --- a/relay_client/src/errors.rs +++ b/relay_client/src/error.rs @@ -1,23 +1,5 @@ -pub use tokio_tungstenite::tungstenite::protocol::CloseFrame; - -pub type WsError = tokio_tungstenite::tungstenite::Error; pub type BoxError = Box; -/// Wrapper around the websocket [`CloseFrame`] providing info about the -/// connection closing reason. -#[derive(Debug, Clone)] -pub struct CloseReason(pub Option>); - -impl std::fmt::Display for CloseReason { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(frame) = &self.0 { - frame.fmt(f) - } else { - f.write_str("") - } - } -} - /// Errors generated while parsing /// [`ConnectionOptions`][crate::ConnectionOptions] and creating an HTTP request /// for the websocket connection. @@ -33,7 +15,10 @@ pub enum RequestBuildError { Url(#[from] url::ParseError), #[error("Failed to create websocket request: {0}")] - Other(WsError), + WebsocketClient(#[from] crate::websocket::WebsocketClientError), + + #[error("Failed to create HTTP request: {0}")] + HttpClient(#[from] crate::http::HttpClientError), } /// Possible Relay client errors. @@ -42,20 +27,11 @@ pub enum Error { #[error("Failed to build connection request: {0}")] RequestBuilder(#[from] RequestBuildError), - #[error("Failed to connect: {0}")] - ConnectionFailed(WsError), - - #[error("Connection closed: {0}")] - ConnectionClosed(CloseReason), - - #[error("Failed to close connection: {0}")] - ClosingFailed(WsError), - - #[error("Not connected")] - NotConnected, + #[error("Websocket client error: {0}")] + WebsocketClient(#[from] crate::websocket::WebsocketClientError), - #[error("Websocket error: {0}")] - Socket(WsError), + #[error("HTTP client error: {0}")] + HttpClient(#[from] crate::http::HttpClientError), #[error("Internal error: Channel closed")] ChannelClosed, diff --git a/relay_client/src/http.rs b/relay_client/src/http.rs new file mode 100644 index 0000000..799e4e8 --- /dev/null +++ b/relay_client/src/http.rs @@ -0,0 +1,292 @@ +use { + crate::{ + error::{BoxError, Error}, + ConnectionOptions, + MessageIdGenerator, + }, + http::{HeaderMap, StatusCode}, + relay_rpc::{ + auth::ed25519_dalek::Keypair, + domain::{DecodedClientId, SubscriptionId, Topic}, + jwt::{self, JwtError, VerifyableClaims}, + rpc::{self, RequestPayload}, + }, + std::{sync::Arc, time::Duration}, + url::Url, +}; + +pub type TransportError = reqwest::Error; +pub type Response = Result<::Response, Error>; +pub type EmptyResponse = Result<(), Error>; + +#[derive(Debug, thiserror::Error)] +pub enum RequestParamsError { + #[error("Invalid TTL")] + InvalidTtl, +} + +#[derive(Debug, thiserror::Error)] +pub enum HttpClientError { + #[error("HTTP transport error: {0}")] + Transport(#[from] TransportError), + + #[error("Invalid request: {0}")] + InvalidRequest(BoxError), + + #[error("Invalid response")] + InvalidResponse, + + #[error("Invalid HTTP status: {0}")] + InvalidHttpCode(StatusCode), + + #[error("JWT error: {0}")] + Jwt(#[from] JwtError), + + #[error("RPC error: code={} message={}", .0.code, .0.message)] + RpcError(rpc::ErrorData), +} + +#[derive(Debug, Clone)] +pub struct WatchRegisterRequest { + /// Service URL. + pub service_url: String, + /// Webhook URL. + pub webhook_url: String, + /// Watcher type. Either subscriber or publisher. + pub watch_type: rpc::WatchType, + /// Array of message tags to watch. + pub tags: Vec, + /// Array of statuses to watch. + pub statuses: Vec, + /// TTL for the registration. + pub ttl: Duration, +} + +#[derive(Debug, Clone)] +pub struct WatchUnregisterRequest { + /// Service URL. + pub service_url: String, + /// Webhook URL. + pub webhook_url: String, + /// Watcher type. Either subscriber or publisher. + pub watch_type: rpc::WatchType, +} + +/// The Relay HTTP RPC client. +#[derive(Debug, Clone)] +pub struct Client { + client: reqwest::Client, + url: Url, + origin: String, + id_generator: MessageIdGenerator, +} + +impl Client { + pub fn new(opts: &ConnectionOptions) -> Result { + let mut headers = HeaderMap::new(); + opts.update_request_headers(&mut headers)?; + + let client = reqwest::Client::builder() + .default_headers(headers) + .build() + .map_err(HttpClientError::Transport)?; + + let url = opts.as_url()?; + let origin = url.origin().unicode_serialization(); + let id_generator = MessageIdGenerator::new(); + + Ok(Self { + client, + url, + origin, + id_generator, + }) + } + + /// Publishes a message over the network on given topic. + pub async fn publish( + &self, + topic: Topic, + message: impl Into>, + tag: u32, + ttl: Duration, + ) -> EmptyResponse { + let ttl_secs = ttl + .as_secs() + .try_into() + .map_err(|_| HttpClientError::InvalidRequest(RequestParamsError::InvalidTtl.into()))?; + + self.request(rpc::Publish { + topic, + message: message.into(), + ttl_secs, + tag, + prompt: false, + }) + .await + .map(|_| ()) + } + + /// Subscribes on topic to receive messages. + pub async fn subscribe(&self, topic: Topic) -> Response { + self.request(rpc::Subscribe { topic }).await + } + + /// Unsubscribes from a topic. + pub async fn unsubscribe( + &self, + topic: Topic, + subscription_id: SubscriptionId, + ) -> Response { + self.request(rpc::Unsubscribe { + topic, + subscription_id, + }) + .await + } + + /// Fetch mailbox messages for a specific topic. + pub async fn fetch(&self, topic: Topic) -> Response { + self.request(rpc::FetchMessages { topic }).await + } + + /// Registers a webhook to watch messages. + pub async fn watch_register( + &self, + request: WatchRegisterRequest, + keypair: &Keypair, + ) -> Response { + let iat = chrono::Utc::now().timestamp(); + let ttl_sec: i64 = request + .ttl + .as_secs() + .try_into() + .map_err(|err| HttpClientError::InvalidRequest(Box::new(err)))?; + let exp = iat + ttl_sec; + + let claims = rpc::WatchRegisterClaims { + basic: jwt::JwtBasicClaims { + iss: DecodedClientId::from_key(&keypair.public_key()).into(), + aud: self.origin.clone(), + iat, + sub: request.service_url, + exp: Some(exp), + }, + act: rpc::WatchAction::Register, + typ: request.watch_type, + whu: request.webhook_url, + tag: request.tags, + sts: request.statuses, + }; + + let payload = rpc::WatchRegister { + register_auth: claims.encode(keypair).map_err(HttpClientError::Jwt)?, + }; + + self.request(payload).await + } + + /// Unregisters a webhook to watch messages. + pub async fn watch_unregister( + &self, + request: WatchUnregisterRequest, + keypair: &Keypair, + ) -> Response { + let iat = chrono::Utc::now().timestamp(); + + let claims = rpc::WatchUnregisterClaims { + basic: jwt::JwtBasicClaims { + iss: DecodedClientId::from_key(&keypair.public_key()).into(), + aud: self.origin.clone(), + iat, + sub: request.service_url, + exp: None, + }, + act: rpc::WatchAction::Unregister, + typ: request.watch_type, + whu: request.webhook_url, + }; + + let payload = rpc::WatchUnregister { + unregister_auth: claims.encode(keypair).map_err(HttpClientError::Jwt)?, + }; + + self.request(payload).await + } + + /// Subscribes on multiple topics to receive messages. + pub async fn batch_subscribe( + &self, + topics: impl Into>, + ) -> Response { + self.request(rpc::BatchSubscribe { + topics: topics.into(), + }) + .await + } + + /// Unsubscribes from multiple topics. + pub async fn batch_unsubscribe( + &self, + subscriptions: impl Into>, + ) -> Response { + self.request(rpc::BatchUnsubscribe { + subscriptions: subscriptions.into(), + }) + .await + } + + /// Fetch mailbox messages for multiple topics. + pub async fn batch_fetch( + &self, + topics: impl Into>, + ) -> Response { + self.request(rpc::BatchFetchMessages { + topics: topics.into(), + }) + .await + } + + pub(crate) async fn request(&self, payload: T) -> Response + where + T: RequestPayload, + { + let payload = rpc::Payload::Request(rpc::Request { + id: self.id_generator.next(), + jsonrpc: rpc::JSON_RPC_VERSION.clone(), + params: payload.into_params(), + }); + + let result = self + .client + .post(self.url.clone()) + .json(&payload) + .send() + .await + .map_err(HttpClientError::Transport)?; + + let status = result.status(); + + if !status.is_success() { + return Err(HttpClientError::InvalidHttpCode(status).into()); + } + + let response = result + .json::() + .await + .map_err(|_| HttpClientError::InvalidResponse)?; + + match response { + rpc::Payload::Response(rpc::Response::Success(response)) => { + serde_json::from_value(response.result) + .map_err(|_| HttpClientError::InvalidResponse.into()) + } + + rpc::Payload::Response(rpc::Response::Error(response)) => { + Err(HttpClientError::RpcError(response.error).into()) + } + + _ => Err(HttpClientError::InvalidResponse.into()), + } + } +} diff --git a/relay_client/src/lib.rs b/relay_client/src/lib.rs index a0a2ea0..6fccffd 100644 --- a/relay_client/src/lib.rs +++ b/relay_client/src/lib.rs @@ -1,17 +1,24 @@ -pub use {client::*, errors::*}; use { + crate::error::{Error, RequestBuildError}, + ::http::HeaderMap, relay_rpc::{ auth::{SerializedAuthToken, RELAY_WEBSOCKET_ADDRESS}, - domain::ProjectId, + domain::{MessageId, ProjectId}, user_agent::UserAgent, }, serde::Serialize, - tokio_tungstenite::tungstenite::{client::IntoClientRequest, http}, + std::sync::{ + atomic::{AtomicU8, Ordering}, + Arc, + }, url::Url, }; -mod client; -mod errors; +pub mod error; +pub mod http; +pub mod websocket; + +pub type HttpRequest = ::http::Request; /// Relay authorization method. A wrapper around [`SerializedAuthToken`]. #[derive(Debug, Clone)] @@ -70,49 +77,51 @@ impl ConnectionOptions { self } - fn into_request(self) -> Result, Error> { - let ConnectionOptions { - address, - project_id, - auth, - origin, - user_agent, - } = self; - - let query = { - let auth = if let Authorization::Query(auth) = &auth { - Some(auth.to_owned()) + pub fn as_url(&self) -> Result { + #[derive(Serialize)] + #[serde(rename_all = "camelCase")] + struct QueryParams<'a> { + project_id: &'a ProjectId, + auth: Option<&'a SerializedAuthToken>, + ua: Option<&'a UserAgent>, + } + + let query = serde_qs::to_string(&QueryParams { + project_id: &self.project_id, + auth: if let Authorization::Query(auth) = &self.auth { + Some(auth) } else { None - }; - - #[derive(Serialize)] - #[serde(rename_all = "camelCase")] - struct QueryParams { - project_id: ProjectId, - auth: Option, - ua: Option, - } - - let query = QueryParams { - project_id, - auth, - ua: user_agent, - }; - - serde_qs::to_string(&query).map_err(RequestBuildError::Query)? - }; + }, + ua: self.user_agent.as_ref(), + }) + .map_err(RequestBuildError::Query)?; - let mut url = Url::parse(&address).map_err(RequestBuildError::Url)?; + let mut url = Url::parse(&self.address).map_err(RequestBuildError::Url)?; url.set_query(Some(&query)); + Ok(url) + } + + fn as_ws_request(&self) -> Result, RequestBuildError> { + use { + crate::websocket::WebsocketClientError, + tokio_tungstenite::tungstenite::client::IntoClientRequest, + }; + + let url = self.as_url()?; + let mut request = url .into_client_request() - .map_err(RequestBuildError::Other)?; + .map_err(WebsocketClientError::Transport)?; + + self.update_request_headers(request.headers_mut())?; - let headers = request.headers_mut(); + Ok(request) + } - if let Authorization::Header(token) = &auth { + fn update_request_headers(&self, headers: &mut HeaderMap) -> Result<(), RequestBuildError> { + if let Authorization::Header(token) = &self.auth { let value = format!("Bearer {token}") .parse() .map_err(|_| RequestBuildError::Headers)?; @@ -120,12 +129,68 @@ impl ConnectionOptions { headers.append("Authorization", value); } - if let Some(origin) = &origin { + if let Some(origin) = &self.origin { let value = origin.parse().map_err(|_| RequestBuildError::Headers)?; headers.append("Origin", value); } - Ok(request) + Ok(()) + } +} + +/// Generates unique message IDs for use in RPC requests. Uses 56 bits for the +/// timestamp with millisecond precision, with the last 8 bits from a monotonic +/// counter. Capable of producing up to `256000` unique values per second. +#[derive(Debug, Clone)] +pub struct MessageIdGenerator { + next: Arc, +} + +impl MessageIdGenerator { + pub fn new() -> Self { + Self::default() + } + + /// Generates a [`MessageId`]. + pub fn next(&self) -> MessageId { + let next = self.next.fetch_add(1, Ordering::Relaxed) as u64; + let timestamp = chrono::Utc::now().timestamp_millis() as u64; + let id = timestamp << 8 | next; + + MessageId::new(id) + } +} + +impl Default for MessageIdGenerator { + fn default() -> Self { + Self { + next: Arc::new(AtomicU8::new(0)), + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + std::{collections::HashSet, hash::Hash}, + }; + + fn elements_unique(iter: T) -> bool + where + T: IntoIterator, + T::Item: Eq + Hash, + { + let mut set = HashSet::new(); + iter.into_iter().all(move |x| set.insert(x)) + } + + #[test] + fn unique_message_ids() { + let gen = MessageIdGenerator::new(); + // N.B. We can produce up to 256 unique values within 1ms. + let values = (0..256).map(move |_| gen.next()).collect::>(); + assert!(elements_unique(values)); } } diff --git a/relay_client/src/client.rs b/relay_client/src/websocket.rs similarity index 83% rename from relay_client/src/client.rs rename to relay_client/src/websocket.rs index f8ae4cb..3359a46 100644 --- a/relay_client/src/client.rs +++ b/relay_client/src/websocket.rs @@ -1,6 +1,6 @@ use { self::connection::{connection_event_loop, ConnectionControl}, - crate::{ConnectionOptions, Error}, + crate::{error::Error, ConnectionOptions}, relay_rpc::{ domain::{SubscriptionId, Topic}, rpc::{ @@ -19,9 +19,49 @@ use { mpsc::{self, UnboundedSender}, oneshot, }, +}; +pub use { + fetch::*, + inbound::*, + outbound::*, + stream::*, tokio_tungstenite::tungstenite::protocol::CloseFrame, }; -pub use {fetch::*, inbound::*, outbound::*, stream::*}; + +pub type TransportError = tokio_tungstenite::tungstenite::Error; + +#[derive(Debug, thiserror::Error)] +pub enum WebsocketClientError { + #[error("Failed to connect: {0}")] + ConnectionFailed(TransportError), + + #[error("Connection closed: {0}")] + ConnectionClosed(CloseReason), + + #[error("Failed to close connection: {0}")] + ClosingFailed(TransportError), + + #[error("Websocket transport error: {0}")] + Transport(TransportError), + + #[error("Not connected")] + NotConnected, +} + +/// Wrapper around the websocket [`CloseFrame`] providing info about the +/// connection closing reason. +#[derive(Debug, Clone)] +pub struct CloseReason(pub Option>); + +impl std::fmt::Display for CloseReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(frame) = &self.0 { + frame.fmt(f) + } else { + f.write_str("") + } + } +} mod connection; mod fetch; @@ -73,7 +113,7 @@ pub trait ConnectionHandler: Send + 'static { fn outbound_error(&mut self, _error: Error) {} } -/// The Relay RPC client. +/// The Relay WebSocket RPC client. /// /// This provides the high-level access to all of the available RPC methods. For /// a lower-level RPC stream see [`ClientStream`](crate::client::ClientStream). @@ -192,13 +232,13 @@ impl Client { } /// Opens a connection to the Relay. - pub async fn connect(&self, opts: ConnectionOptions) -> Result<(), Error> { + pub async fn connect(&self, opts: &ConnectionOptions) -> Result<(), Error> { let (tx, rx) = oneshot::channel(); - let opts = Box::new(opts); + let request = opts.as_ws_request()?; if self .control_tx - .send(ConnectionControl::Connect { opts, tx }) + .send(ConnectionControl::Connect { request, tx }) .is_ok() { rx.await.map_err(|_| Error::ChannelClosed)? diff --git a/relay_client/src/client/connection.rs b/relay_client/src/websocket/connection.rs similarity index 84% rename from relay_client/src/client/connection.rs rename to relay_client/src/websocket/connection.rs index 4186a19..22a2b3d 100644 --- a/relay_client/src/client/connection.rs +++ b/relay_client/src/websocket/connection.rs @@ -2,14 +2,14 @@ use { super::{ outbound::OutboundRequest, stream::{create_stream, ClientStream}, + ConnectionHandler, + TransportError, + WebsocketClientError, }, crate::{ - client::stream::StreamEvent, - ConnectionHandler, - ConnectionOptions, + websocket::{stream::StreamEvent, PublishedMessage}, Error, - PublishedMessage, - WsError, + HttpRequest, }, futures_util::{stream::FusedStream, Stream, StreamExt}, std::{ @@ -21,7 +21,7 @@ use { pub(super) enum ConnectionControl { Connect { - opts: Box, + request: HttpRequest<()>, tx: oneshot::Sender>, }, @@ -45,8 +45,8 @@ pub(super) async fn connection_event_loop( event = control_rx.recv() => { match event { Some(event) => match event { - ConnectionControl::Connect { tx, opts } => { - let result = conn.connect(*opts).await; + ConnectionControl::Connect { request, tx } => { + let result = conn.connect(request).await; if result.is_ok() { handler.connected(); @@ -107,12 +107,12 @@ impl Connection { Self { stream: None } } - async fn connect(&mut self, opts: ConnectionOptions) -> Result<(), Error> { + async fn connect(&mut self, request: HttpRequest<()>) -> Result<(), Error> { if let Some(mut stream) = self.stream.take() { stream.close(None).await?; } - self.stream = Some(create_stream(opts).await?); + self.stream = Some(create_stream(request).await?); Ok(()) } @@ -123,7 +123,7 @@ impl Connection { match stream { Some(mut stream) => stream.close(None).await, - None => Err(Error::ClosingFailed(WsError::AlreadyClosed)), + None => Err(WebsocketClientError::ClosingFailed(TransportError::AlreadyClosed).into()), } } @@ -132,7 +132,10 @@ impl Connection { Some(stream) => stream.send_raw(request), None => { - request.tx.send(Err(Error::NotConnected)).ok(); + request + .tx + .send(Err(WebsocketClientError::NotConnected.into())) + .ok(); } } } diff --git a/relay_client/src/client/fetch.rs b/relay_client/src/websocket/fetch.rs similarity index 98% rename from relay_client/src/client/fetch.rs rename to relay_client/src/websocket/fetch.rs index 223f2f6..6030455 100644 --- a/relay_client/src/client/fetch.rs +++ b/relay_client/src/websocket/fetch.rs @@ -1,5 +1,6 @@ use { - crate::{create_request, Client, Error, ResponseFuture}, + super::{create_request, Client, ResponseFuture}, + crate::Error, futures_util::{FutureExt, Stream}, relay_rpc::{ domain::Topic, diff --git a/relay_client/src/client/inbound.rs b/relay_client/src/websocket/inbound.rs similarity index 100% rename from relay_client/src/client/inbound.rs rename to relay_client/src/websocket/inbound.rs diff --git a/relay_client/src/client/outbound.rs b/relay_client/src/websocket/outbound.rs similarity index 63% rename from relay_client/src/client/outbound.rs rename to relay_client/src/websocket/outbound.rs index 1907fa9..dbbc244 100644 --- a/relay_client/src/client/outbound.rs +++ b/relay_client/src/websocket/outbound.rs @@ -1,18 +1,11 @@ use { crate::Error, pin_project::pin_project, - relay_rpc::{ - domain::MessageId, - rpc::{Params, RequestPayload}, - }, + relay_rpc::rpc::{Params, RequestPayload}, std::{ future::Future, marker::PhantomData, pin::Pin, - sync::{ - atomic::{AtomicU8, Ordering}, - Arc, - }, task::{ready, Context, Poll}, }, tokio::sync::oneshot, @@ -114,59 +107,3 @@ where ResponseFuture::new(rx), ) } - -/// Generates unique message IDs for use in RPC requests. Uses 56 bits for the -/// timestamp with millisecond precision, with the last 8 bits from a monotonic -/// counter. Capable of producing up to `256000` unique values per second. -#[derive(Debug, Clone)] -pub struct MessageIdGenerator { - next: Arc, -} - -impl MessageIdGenerator { - pub fn new() -> Self { - Self::default() - } - - /// Generates a [`MessageId`]. - pub fn next(&self) -> MessageId { - let next = self.next.fetch_add(1, Ordering::Relaxed) as u64; - let timestamp = chrono::Utc::now().timestamp_millis() as u64; - let id = timestamp << 8 | next; - - MessageId::new(id) - } -} - -impl Default for MessageIdGenerator { - fn default() -> Self { - Self { - next: Arc::new(AtomicU8::new(0)), - } - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - std::{collections::HashSet, hash::Hash}, - }; - - fn elements_unique(iter: T) -> bool - where - T: IntoIterator, - T::Item: Eq + Hash, - { - let mut set = HashSet::new(); - iter.into_iter().all(move |x| set.insert(x)) - } - - #[test] - fn unique_message_ids() { - let gen = MessageIdGenerator::new(); - // N.B. We can produce up to 256 unique values within 1ms. - let values = (0..256).map(move |_| gen.next()).collect::>(); - assert!(elements_unique(values)); - } -} diff --git a/relay_client/src/client/stream.rs b/relay_client/src/websocket/stream.rs similarity index 90% rename from relay_client/src/client/stream.rs rename to relay_client/src/websocket/stream.rs index 185c59f..3601b0f 100644 --- a/relay_client/src/client/stream.rs +++ b/relay_client/src/websocket/stream.rs @@ -1,9 +1,12 @@ use { super::{ inbound::InboundRequest, - outbound::{create_request, MessageIdGenerator, OutboundRequest, ResponseFuture}, + outbound::{create_request, OutboundRequest, ResponseFuture}, + CloseReason, + TransportError, + WebsocketClientError, }, - crate::{CloseReason, ConnectionOptions, Error, WsError}, + crate::{error::Error, HttpRequest, MessageIdGenerator}, futures_util::{stream::FusedStream, SinkExt, Stream, StreamExt}, relay_rpc::{ domain::MessageId, @@ -34,10 +37,10 @@ pub type SocketStream = WebSocketStream>; /// Opens a connection to the Relay and returns [`ClientStream`] for the /// connection. -pub async fn create_stream(opts: ConnectionOptions) -> Result { - let (socket, _) = connect_async(opts.into_request()?) +pub async fn create_stream(request: HttpRequest<()>) -> Result { + let (socket, _) = connect_async(request) .await - .map_err(Error::ConnectionFailed)?; + .map_err(WebsocketClientError::ConnectionFailed)?; Ok(ClientStream::new(socket)) } @@ -139,10 +142,13 @@ impl ClientStream { /// Closes the connection. pub async fn close(&mut self, frame: Option>) -> Result<(), Error> { self.close_frame = frame.clone(); - self.socket.close(frame).await.map_err(Error::ClosingFailed) + self.socket + .close(frame) + .await + .map_err(|err| WebsocketClientError::ClosingFailed(err).into()) } - fn parse_inbound(&mut self, result: Result) -> Option { + fn parse_inbound(&mut self, result: Result) -> Option { match result { Ok(message) => match &message { Message::Binary(_) | Message::Text(_) => { @@ -223,11 +229,13 @@ impl ClientStream { _ => None, }, - Err(error) => Some(StreamEvent::InboundError(Error::Socket(error))), + Err(error) => Some(StreamEvent::InboundError( + WebsocketClientError::Transport(error).into(), + )), } } - fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { let mut should_flush = false; loop { @@ -284,9 +292,9 @@ impl Stream for ClientStream { } match self.poll_write(cx) { - Poll::Ready(Err(error)) => { - Poll::Ready(Some(StreamEvent::OutboundError(Error::Socket(error)))) - } + Poll::Ready(Err(error)) => Poll::Ready(Some(StreamEvent::OutboundError( + WebsocketClientError::Transport(error).into(), + ))), _ => Poll::Pending, } @@ -304,7 +312,10 @@ impl Drop for ClientStream { let reason = CloseReason(self.close_frame.take()); for (_, tx) in self.requests.drain() { - tx.send(Err(Error::ConnectionClosed(reason.clone()))).ok(); + tx.send(Err( + WebsocketClientError::ConnectionClosed(reason.clone()).into() + )) + .ok(); } } } diff --git a/relay_rpc/src/auth.rs b/relay_rpc/src/auth.rs index b886726..ebeaef2 100644 --- a/relay_rpc/src/auth.rs +++ b/relay_rpc/src/auth.rs @@ -1,19 +1,19 @@ -#[cfg(test)] -mod tests; - -#[cfg(feature = "cacao")] -pub mod cacao; -pub mod did; - use { - crate::domain::{AuthSubject, ClientId, ClientIdDecodingError, DecodedClientId}, + crate::{ + domain::DecodedClientId, + jwt::{JwtBasicClaims, JwtHeader}, + }, chrono::{DateTime, Utc}, ed25519_dalek::{ed25519::signature::Signature, Keypair, Signer}, serde::{Deserialize, Serialize}, - std::{collections::HashSet, fmt::Display, time::Duration}, + std::{fmt::Display, time::Duration}, }; pub use {chrono, ed25519_dalek, rand}; +#[cfg(feature = "cacao")] +pub mod cacao; +pub mod did; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Invalid duration")] @@ -25,19 +25,10 @@ pub enum Error { pub const RELAY_WEBSOCKET_ADDRESS: &str = "wss://relay.walletconnect.com"; -pub const DID_DELIMITER: &str = ":"; -pub const DID_PREFIX: &str = "did"; -pub const DID_METHOD: &str = "key"; - pub const MULTICODEC_ED25519_BASE: &str = "z"; pub const MULTICODEC_ED25519_HEADER: [u8; 2] = [237, 1]; pub const MULTICODEC_ED25519_LENGTH: usize = 32; -pub const JWT_DELIMITER: &str = "."; -pub const JWT_HEADER_TYP: &str = "JWT"; -pub const JWT_HEADER_ALG: &str = "EdDSA"; -pub const JWT_VALIDATION_TIME_LEEWAY_SECS: i64 = 120; - pub const DEFAULT_TOKEN_AUD: &str = RELAY_WEBSOCKET_ADDRESS; pub const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(60 * 60); @@ -59,14 +50,14 @@ impl From for String { #[derive(Debug, Clone)] pub struct AuthToken { - sub: AuthSubject, + sub: String, aud: Option, iat: Option>, ttl: Option, } impl AuthToken { - pub fn new(sub: impl Into) -> Self { + pub fn new(sub: impl Into) -> Self { Self { sub: sub.into(), aud: None, @@ -95,94 +86,33 @@ impl AuthToken { let ttl = self.ttl.unwrap_or(DEFAULT_TOKEN_TTL); let aud = self.aud.as_deref().unwrap_or(DEFAULT_TOKEN_AUD); - encode_auth_token(key, self.sub.as_ref(), aud, iat, ttl) - } -} - -#[derive(Serialize, Deserialize)] -pub struct JwtHeader<'a> { - pub typ: &'a str, - pub alg: &'a str, -} - -impl<'a> JwtHeader<'a> { - pub fn is_valid(&self) -> bool { - self.typ == JWT_HEADER_TYP && self.alg == JWT_HEADER_ALG - } -} - -#[derive(Serialize, Deserialize)] -pub struct JwtClaims<'a> { - pub iss: &'a str, - pub sub: &'a str, - pub aud: &'a str, - pub iat: i64, - pub exp: i64, -} - -impl<'a> JwtClaims<'a> { - pub fn validate( - &self, - aud: &HashSet, - time_leeway: impl Into>, - ) -> Result<(), JwtVerificationError> { - let time_leeway = time_leeway - .into() - .unwrap_or(JWT_VALIDATION_TIME_LEEWAY_SECS); - let now = Utc::now().timestamp(); - - if now - time_leeway > self.exp { - return Err(JwtVerificationError::Expired); - } - - if now + time_leeway < self.iat { - return Err(JwtVerificationError::NotYetValid); - } - - if !aud.contains(self.aud) { - return Err(JwtVerificationError::InvalidAudience); - } - Ok(()) + encode_auth_token(key, &self.sub, aud, iat, ttl) } } pub fn encode_auth_token( key: &Keypair, - sub: &str, - aud: &str, + sub: impl Into, + aud: impl Into, iat: DateTime, ttl: Duration, ) -> Result { let encoder = &data_encoding::BASE64URL_NOPAD; let exp = iat + chrono::Duration::from_std(ttl).map_err(|_| Error::InvalidDuration)?; - let iss = { - let client_id = DecodedClientId(*key.public_key().as_bytes()); - - format!("{DID_PREFIX}{DID_DELIMITER}{DID_METHOD}{DID_DELIMITER}{client_id}") - }; - let claims = { - let data = JwtClaims { - iss: &iss, - sub, - aud, + let data = JwtBasicClaims { + iss: DecodedClientId::from_key(&key.public_key()).into(), + sub: sub.into(), + aud: aud.into(), iat: iat.timestamp(), - exp: exp.timestamp(), - }; - - encoder.encode(serde_json::to_string(&data)?.as_bytes()) - }; - - let header = { - let data = JwtHeader { - typ: JWT_HEADER_TYP, - alg: JWT_HEADER_ALG, + exp: Some(exp.timestamp()), }; encoder.encode(serde_json::to_string(&data)?.as_bytes()) }; + let header = encoder.encode(serde_json::to_string(&JwtHeader::default())?.as_bytes()); let message = format!("{header}.{claims}"); let signature = { @@ -193,125 +123,3 @@ pub fn encode_auth_token( Ok(SerializedAuthToken(format!("{message}.{signature}"))) } - -#[derive(Debug, thiserror::Error)] -pub enum JwtVerificationError { - #[error("Invalid format")] - Format, - - #[error("Invalid encoding")] - Encoding, - - #[error("Invalid JWT signing algorithm")] - Header, - - #[error("JWT Token is expired")] - Expired, - - #[error("JWT Token is not yet valid")] - NotYetValid, - - #[error("Invalid audience")] - InvalidAudience, - - #[error("Invalid signature")] - Signature, - - #[error("Invalid JSON")] - Serialization, - - #[error("Invalid issuer DID prefix")] - IssuerPrefix, - - #[error("Invalid issuer DID method")] - IssuerMethod, - - #[error("Invalid issuer format")] - IssuerFormat, - - #[error(transparent)] - PubKey(#[from] ClientIdDecodingError), -} - -#[derive(Debug)] -pub struct Jwt(pub String); - -impl Jwt { - pub fn decode(&self, aud: &HashSet) -> Result { - let mut parts = self.0.splitn(3, JWT_DELIMITER); - - let (Some(header), Some(claims)) = (parts.next(), parts.next()) else { - return Err(JwtVerificationError::Format); - }; - - let decoder = &data_encoding::BASE64URL_NOPAD; - - let header_len = decoder - .decode_len(header.len()) - .map_err(|_| JwtVerificationError::Encoding)?; - let claims_len = decoder - .decode_len(claims.len()) - .map_err(|_| JwtVerificationError::Encoding)?; - - let mut output = vec![0u8; header_len.max(claims_len)]; - - // Decode header. - data_encoding::BASE64URL_NOPAD - .decode_mut(header.as_bytes(), &mut output[..header_len]) - .map_err(|_| JwtVerificationError::Encoding)?; - - { - let header = serde_json::from_slice::(&output[..header_len]) - .map_err(|_| JwtVerificationError::Serialization)?; - - if !header.is_valid() { - return Err(JwtVerificationError::Header); - } - } - - // Decode claims. - data_encoding::BASE64URL_NOPAD - .decode_mut(claims.as_bytes(), &mut output[..claims_len]) - .map_err(|_| JwtVerificationError::Encoding)?; - - let claims = serde_json::from_slice::(&output[..claims_len]) - .map_err(|_| JwtVerificationError::Serialization)?; - - // Basic token validation: `iat`, `exp` and `aud`. - claims.validate(aud, None)?; - - let did_key = claims - .iss - .strip_prefix(DID_PREFIX) - .ok_or(JwtVerificationError::IssuerPrefix)? - .strip_prefix(DID_DELIMITER) - .ok_or(JwtVerificationError::IssuerFormat)? - .strip_prefix(DID_METHOD) - .ok_or(JwtVerificationError::IssuerMethod)? - .strip_prefix(DID_DELIMITER) - .ok_or(JwtVerificationError::IssuerFormat)?; - - let pub_key = did_key.parse::()?; - - let mut parts = self.0.rsplitn(2, JWT_DELIMITER); - - let (Some(signature), Some(message)) = (parts.next(), parts.next()) else { - return Err(JwtVerificationError::Format); - }; - - let key = jsonwebtoken::DecodingKey::from_ed_der(pub_key.as_ref()); - - // Finally, verify signature. - let sig_result = jsonwebtoken::crypto::verify( - signature, - message.as_bytes(), - &key, - jsonwebtoken::Algorithm::EdDSA, - ); - - match sig_result { - Ok(true) => Ok(pub_key.into()), - _ => Err(JwtVerificationError::Signature), - } - } -} diff --git a/relay_rpc/src/auth/did.rs b/relay_rpc/src/auth/did.rs index f55bc14..c2e1035 100644 --- a/relay_rpc/src/auth/did.rs +++ b/relay_rpc/src/auth/did.rs @@ -3,7 +3,7 @@ pub const DID_PREFIX: &str = "did"; pub const DID_METHOD_KEY: &str = "key"; pub const DID_METHOD_PKH: &str = "pkh"; -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Clone, thiserror::Error)] pub enum DidError { #[error("Invalid issuer DID prefix")] Prefix, @@ -25,3 +25,7 @@ pub fn extract_did_data<'a>(did: &'a str, method: &str) -> Result<&'a str, DidEr .strip_prefix(DID_DELIMITER) .ok_or(DidError::Format) } + +pub fn combine_did_data(method: &str, data: &str) -> String { + format!("{DID_PREFIX}{DID_DELIMITER}{method}{DID_DELIMITER}{data}") +} diff --git a/relay_rpc/src/auth/tests.rs b/relay_rpc/src/auth/tests.rs deleted file mode 100644 index b84a4e8..0000000 --- a/relay_rpc/src/auth/tests.rs +++ /dev/null @@ -1,131 +0,0 @@ -use { - crate::{ - auth::{AuthToken, Jwt, JwtVerificationError, JWT_VALIDATION_TIME_LEEWAY_SECS}, - domain::{ClientIdDecodingError, DecodedAuthSubject}, - }, - ed25519_dalek::Keypair, - std::{collections::HashSet, time::Duration}, -}; - -#[test] -fn token_validation() { - let aud = HashSet::from(["wss://relay.walletconnect.com".to_owned()]); - - // Invalid signature. - let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Ono2TWtvZEhad25lVlJTaHRhTGY4SktZa3hwREdwMXZHWm5wR21kQnBYOE0yZXh4SCIsInN1YiI6ImM0NzlmZTVkYzQ2NGU3NzFlNzhiMTkzZDIzOWE2NWI1OGQyNzhjYWQxYzM0YmZiMGI1NzE2ZTViYjUxNDkyOGUiLCJhdWQiOiJ3c3M6Ly9yZWxheS53YWxsZXRjb25uZWN0LmNvbSIsImlhdCI6MTY1NjkxMDA5NywiZXhwIjo0ODEyNjcwMDk3fQ.CLryc7bGZ_mBVh-P5p2tDDkjY8m9ji9xZXixJCbLLd4TMBh7F0EkChbWOOUQp4DyXUVK4CN-hxMZgt2xnePUBAx".to_owned()); - assert!(matches!( - jwt.decode(&aud), - Err(JwtVerificationError::Signature) - )); - - // Invalid multicodec header. - let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Ono2TWt2eDRWVnVCQlBIekVvTERiNWdOQzRyUW1uSnN0YzFib29oS2ZjSlV0OU12NjUiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.ixjxEISufsDpdsp4MRwD4Q100d8s7v4mSlIWIad6q8Nh__768pzPaCAVXQIZLxKPhuJQ92cZi7tVUJtAE1_UCg".to_owned()); - assert!(matches!( - jwt.decode(&aud), - Err(JwtVerificationError::PubKey( - ClientIdDecodingError::Encoding - )) - )); - - // Invalid multicodec base. - let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Onh6Nk1rb2RIWnduZVZSU2h0YUxmOEpLWWt4cERHcDF2R1pucEdtZEJwWDhNMmV4eEgiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.BINvB6JpUyp5Zs7qbIYMv7KybptioYFZP89ZFTMtvdGvEnRpYg70uzwSLdhZB1EPJZIrUMhybfT7Q1DYEqHwDw".to_owned()); - assert!(matches!( - jwt.decode(&aud), - Err(JwtVerificationError::PubKey(ClientIdDecodingError::Base)) - )); - - // Invalid DID prefix. - let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ4ZGlkOmtleTp6Nk1rb2RIWnduZVZSU2h0YUxmOEpLWWt4cERHcDF2R1pucEdtZEJwWDhNMmV4eEgiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.GGhlhz7kXCqCTUsn390O_hA9YQDa61d_DDiSVLsa70xrgFrGmjjoWWl1dsZn3RVq4V1IB0P1__NDJ2PK0OMiDA".to_owned()); - assert!(matches!( - jwt.decode(&aud), - Err(JwtVerificationError::IssuerPrefix) - )); - - // Invalid DID method - let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6eGtleTp6Nk1rb2RIWnduZVZSU2h0YUxmOEpLWWt4cERHcDF2R1pucEdtZEJwWDhNMmV4eEgiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.rogEwjJLQFwbDm4psUty7MPkHrCrNiXxpwEYZ2nctppmF7MYvC3g7URZNYkKxMbFtNZ1hFCwsr1peEu3pVeJCg".to_owned()); - assert!(matches!( - jwt.decode(&aud), - Err(JwtVerificationError::IssuerMethod) - )); - - // Invalid issuer base58. - let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Ono2TWtvZEhad25lVlJTaHRhTGY4SktZa3hwREdwMXZHWm5wR21kQnBYOE0yZXh4SGwiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.nLdxz4f6yJ8HsWZJUvpSHjFjoat4PfJav-kyqdHj6JXcX5SyDvp3QNB9doyzRWb9jpbA36Av0qn4kqLl-pGuBg".to_owned()); - assert!(matches!( - jwt.decode(&aud), - Err(JwtVerificationError::PubKey( - ClientIdDecodingError::Encoding - )) - )); - - let keypair = Keypair::generate(&mut rand::thread_rng()); - - // IAT in future. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .iat(chrono::Utc::now() + chrono::Duration::hours(1)) - .as_jwt(&keypair) - .unwrap(); - assert!(matches!( - Jwt(jwt.into()).decode(&aud), - Err(JwtVerificationError::NotYetValid) - )); - - // IAT leeway, valid. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .iat(chrono::Utc::now() + chrono::Duration::seconds(JWT_VALIDATION_TIME_LEEWAY_SECS)) - .as_jwt(&keypair) - .unwrap(); - assert!(matches!(Jwt(jwt.into()).decode(&aud), Ok(_))); - - // IAT leeway, invalid. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .iat(chrono::Utc::now() + chrono::Duration::seconds(JWT_VALIDATION_TIME_LEEWAY_SECS + 1)) - .as_jwt(&keypair) - .unwrap(); - assert!(matches!( - Jwt(jwt.into()).decode(&aud), - Err(JwtVerificationError::NotYetValid) - )); - - // Past expiration. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .iat(chrono::Utc::now() - chrono::Duration::hours(2)) - .ttl(Duration::from_secs(3600)) - .as_jwt(&keypair) - .unwrap(); - assert!(matches!( - Jwt(jwt.into()).decode(&aud), - Err(JwtVerificationError::Expired) - )); - - // Expiration leeway, valid. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .iat(chrono::Utc::now() - chrono::Duration::seconds(3600 + JWT_VALIDATION_TIME_LEEWAY_SECS)) - .ttl(Duration::from_secs(3600)) - .as_jwt(&keypair) - .unwrap(); - assert!(matches!(Jwt(jwt.into()).decode(&aud), Ok(_))); - - // Expiration leeway, invalid. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .iat( - chrono::Utc::now() - - chrono::Duration::seconds(3600 + JWT_VALIDATION_TIME_LEEWAY_SECS + 1), - ) - .ttl(Duration::from_secs(3600)) - .as_jwt(&keypair) - .unwrap(); - assert!(matches!( - Jwt(jwt.into()).decode(&aud), - Err(JwtVerificationError::Expired) - )); - - // Invalid aud. - let jwt = AuthToken::new(DecodedAuthSubject::generate()) - .aud("wss://not.relay.walletconnect.com") - .as_jwt(&keypair) - .unwrap(); - assert!(matches!( - Jwt(jwt.into()).decode(&aud), - Err(JwtVerificationError::InvalidAudience) - )); -} diff --git a/relay_rpc/src/domain.rs b/relay_rpc/src/domain.rs index 4694aa9..146a198 100644 --- a/relay_rpc/src/domain.rs +++ b/relay_rpc/src/domain.rs @@ -1,17 +1,20 @@ use { crate::{ - auth::{MULTICODEC_ED25519_BASE, MULTICODEC_ED25519_HEADER, MULTICODEC_ED25519_LENGTH}, + auth::{ + did::{combine_did_data, extract_did_data, DidError, DID_METHOD_KEY}, + MULTICODEC_ED25519_BASE, + MULTICODEC_ED25519_HEADER, + MULTICODEC_ED25519_LENGTH, + }, new_type, }, derive_more::{AsMut, AsRef}, + ed25519_dalek::PublicKey, serde::{Deserialize, Serialize}, serde_aux::prelude::deserialize_number_from_string, std::{str::FromStr, sync::Arc}, }; -#[cfg(test)] -mod tests; - #[derive(Debug, Clone, thiserror::Error)] pub enum ClientIdDecodingError { #[error("Invalid issuer multicodec base")] @@ -23,6 +26,9 @@ pub enum ClientIdDecodingError { #[error("Invalid multicodec header")] Header, + #[error("Invalid DID key data: {0}")] + Did(#[from] DidError), + #[error("Invalid issuer pubkey length")] Length, } @@ -63,11 +69,85 @@ impl TryFrom for DecodedClientId { } } +impl From for ClientId { + fn from(val: DidKey) -> Self { + val.0.into() + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, AsRef, AsMut, Serialize, Deserialize)] +#[as_ref(forward)] +#[as_mut(forward)] +pub struct DidKey( + #[serde(with = "crate::serde_helpers::client_id_as_did_key")] pub DecodedClientId, +); + +impl From for PublicKey { + fn from(val: DidKey) -> Self { + val.0.as_public_key() + } +} + +impl From for DidKey { + fn from(val: DecodedClientId) -> Self { + Self(val) + } +} + +impl TryFrom for DidKey { + type Error = ClientIdDecodingError; + + fn try_from(value: ClientId) -> Result { + value.decode().map(Self) + } +} + #[derive(Debug, Default, Clone, PartialEq, Eq, Hash, AsRef, AsMut, Serialize, Deserialize)] #[as_ref(forward)] #[as_mut(forward)] pub struct DecodedClientId(pub [u8; MULTICODEC_ED25519_LENGTH]); +impl DecodedClientId { + #[inline] + pub fn try_from_did_key(did: &str) -> Result { + extract_did_data(did, DID_METHOD_KEY)?.parse() + } + + #[inline] + pub fn to_did_key(&self) -> String { + combine_did_data(DID_METHOD_KEY, &self.to_string()) + } + + #[inline] + pub fn from_key(key: &PublicKey) -> Self { + Self(*key.as_bytes()) + } + + #[inline] + pub fn as_public_key(&self) -> PublicKey { + // We know that the length is correct, so we can just unwrap. + PublicKey::from_bytes(&self.0).unwrap() + } +} + +impl From for DecodedClientId { + fn from(key: PublicKey) -> Self { + Self::from_key(&key) + } +} + +impl From for PublicKey { + fn from(val: DecodedClientId) -> Self { + val.as_public_key() + } +} + +impl From for DecodedClientId { + fn from(val: DidKey) -> Self { + val.0 + } +} + impl FromStr for DecodedClientId { type Err = ClientIdDecodingError; @@ -248,3 +328,34 @@ impl_byte_array_newtype!(DecodedTopic, Topic, 32); impl_byte_array_newtype!(DecodedSubscription, SubscriptionId, 32); impl_byte_array_newtype!(DecodedAuthSubject, AuthSubject, 32); impl_byte_array_newtype!(DecodedProjectId, ProjectId, 16); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn client_id_decoding() { + let client_id_str = "z6MkodHZwneVRShtaLf8JKYkxpDGp1vGZnpGmdBpX8M2exxH"; + let client_id_bin = client_id_str.parse::().unwrap(); + + assert_eq!(client_id_str, ClientId::from(client_id_bin).as_ref()); + + assert!(matches!( + "z6MkodHZwne".parse::(), + Err(ClientIdDecodingError::Length) + )); + } + + #[test] + fn topic_decoding() { + let topic_str = "85089843cebc89ce5bbffd55377b2e65c8a32c2d0a76742f2d6852b5f531a460"; + let topic_bin = topic_str.parse::().unwrap(); + + assert_eq!(topic_str, Topic::from(topic_bin).as_ref()); + + assert!(matches!( + "85089843ce".parse::(), + Err(DecodingError::Length) + )); + } +} diff --git a/relay_rpc/src/domain/tests.rs b/relay_rpc/src/domain/tests.rs deleted file mode 100644 index 556da9c..0000000 --- a/relay_rpc/src/domain/tests.rs +++ /dev/null @@ -1,27 +0,0 @@ -use super::*; - -#[test] -fn client_id_decoding() { - let client_id_str = "z6MkodHZwneVRShtaLf8JKYkxpDGp1vGZnpGmdBpX8M2exxH"; - let client_id_bin = client_id_str.parse::().unwrap(); - - assert_eq!(client_id_str, ClientId::from(client_id_bin).as_ref()); - - assert!(matches!( - "z6MkodHZwne".parse::(), - Err(ClientIdDecodingError::Length) - )); -} - -#[test] -fn topic_decoding() { - let topic_str = "85089843cebc89ce5bbffd55377b2e65c8a32c2d0a76742f2d6852b5f531a460"; - let topic_bin = topic_str.parse::().unwrap(); - - assert_eq!(topic_str, Topic::from(topic_bin).as_ref()); - - assert!(matches!( - "85089843ce".parse::(), - Err(DecodingError::Length) - )); -} diff --git a/relay_rpc/src/jwt.rs b/relay_rpc/src/jwt.rs new file mode 100644 index 0000000..e9405aa --- /dev/null +++ b/relay_rpc/src/jwt.rs @@ -0,0 +1,348 @@ +use { + crate::domain::DidKey, + chrono::Utc, + ed25519_dalek::{ed25519::signature::Signature, Keypair, PublicKey, Signer}, + serde::{de::DeserializeOwned, Deserialize, Serialize}, + std::collections::HashSet, +}; + +pub const JWT_DELIMITER: &str = "."; +pub const JWT_HEADER_TYP: &str = "JWT"; +pub const JWT_HEADER_ALG: &str = "EdDSA"; +pub const JWT_VALIDATION_TIME_LEEWAY_SECS: i64 = 120; + +#[derive(Debug, thiserror::Error)] +pub enum JwtError { + #[error("Invalid format")] + Format, + + #[error("Invalid encoding")] + Encoding, + + #[error("Invalid JWT signing algorithm")] + Header, + + #[error("JWT Token is expired")] + Expired, + + #[error("JWT Token is not yet valid")] + NotYetValid, + + #[error("Invalid audience")] + InvalidAudience, + + #[error("Invalid signature")] + Signature, + + #[error("Encoding keypair mismatch")] + InvalidKeypair, + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), +} + +#[derive(Serialize, Deserialize)] +pub struct JwtHeader<'a> { + #[serde(borrow)] + pub typ: &'a str, + #[serde(borrow)] + pub alg: &'a str, +} + +impl Default for JwtHeader<'_> { + fn default() -> Self { + Self { + typ: JWT_HEADER_TYP, + alg: JWT_HEADER_ALG, + } + } +} + +impl<'a> JwtHeader<'a> { + pub fn is_valid(&self) -> bool { + self.typ == JWT_HEADER_TYP && self.alg == JWT_HEADER_ALG + } +} + +/// Basic JWT claims that are common to all JWTs used by the Relay. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct JwtBasicClaims { + /// Client ID matching the watch type. + pub iss: DidKey, + /// Relay URL. + pub aud: String, + /// Service URL. + pub sub: String, + /// Issued at, timestamp. + pub iat: i64, + /// Expiration, timestamp. + pub exp: Option, +} + +impl VerifyableClaims for JwtBasicClaims { + fn basic(&self) -> &JwtBasicClaims { + self + } +} + +pub trait VerifyableClaims: Serialize + DeserializeOwned { + /// Returns a reference to the basic claims, which may be a part of a larger + /// set of claims. + fn basic(&self) -> &JwtBasicClaims; + + /// Encodes the claims into a JWT string, signing it with the provided key. + /// Returns an error if the provided key does not match the public key in + /// the claims (`iss`), or if serialization fails. + fn encode(&self, key: &Keypair) -> Result { + let public_key = PublicKey::from_bytes(self.basic().iss.as_ref()) + .map_err(|_| JwtError::InvalidKeypair)?; + + // Make sure the keypair matches the public key in the claims. + if public_key != key.public_key() { + return Err(JwtError::InvalidKeypair); + } + + let encoder = &data_encoding::BASE64URL_NOPAD; + let header = encoder.encode(serde_json::to_string(&JwtHeader::default())?.as_bytes()); + let claims = encoder.encode(serde_json::to_string(self)?.as_bytes()); + let message = format!("{header}.{claims}"); + let signature = encoder.encode(key.sign(message.as_bytes()).as_bytes()); + + Ok(format!("{message}.{signature}")) + } + + /// Tries to parse the claims from a string, returning an error if the + /// parsing fails for any reason. + /// + /// Note: This does not perorm the actual verification of the claims. After + /// successful decoding, the claims should be verified using the + /// [`VerifyableClaims::verify_basic()`] method. + fn try_from_str(data: &str) -> Result + where + Self: Sized, + { + let mut parts = data.splitn(3, JWT_DELIMITER); + + let (Some(header), Some(claims)) = (parts.next(), parts.next()) else { + return Err(JwtError::Format); + }; + + let decoder = &data_encoding::BASE64URL_NOPAD; + + let header_len = decoder + .decode_len(header.len()) + .map_err(|_| JwtError::Encoding)?; + let claims_len = decoder + .decode_len(claims.len()) + .map_err(|_| JwtError::Encoding)?; + + let mut output = vec![0u8; header_len.max(claims_len)]; + + // Decode header. + data_encoding::BASE64URL_NOPAD + .decode_mut(header.as_bytes(), &mut output[..header_len]) + .map_err(|_| JwtError::Encoding)?; + + { + let header = serde_json::from_slice::(&output[..header_len]) + .map_err(JwtError::Serialization)?; + + if !header.is_valid() { + return Err(JwtError::Header); + } + } + + // Decode claims. + data_encoding::BASE64URL_NOPAD + .decode_mut(claims.as_bytes(), &mut output[..claims_len]) + .map_err(|_| JwtError::Encoding)?; + + let claims = serde_json::from_slice::(&output[..claims_len]) + .map_err(JwtError::Serialization)?; + + let mut parts = data.rsplitn(2, JWT_DELIMITER); + + let (Some(signature), Some(message)) = (parts.next(), parts.next()) else { + return Err(JwtError::Format); + }; + + let key = jsonwebtoken::DecodingKey::from_ed_der(claims.basic().iss.as_ref()); + + // Finally, verify signature. + let sig_result = jsonwebtoken::crypto::verify( + signature, + message.as_bytes(), + &key, + jsonwebtoken::Algorithm::EdDSA, + ); + + match sig_result { + Ok(true) => Ok(claims), + + _ => Err(JwtError::Signature), + } + } + + /// Performs basic verification of the claims. This includes the following + /// checks: + /// - The token is not expired (with a configurable leeway). This is + /// optional if the token has an `exp` value; + /// - The token is not used before it's valid; + /// - The token is issued for the correct audience. + fn verify_basic( + &self, + aud: &HashSet, + time_leeway: impl Into>, + ) -> Result<(), JwtError> { + let basic = self.basic(); + let time_leeway = time_leeway + .into() + .unwrap_or(JWT_VALIDATION_TIME_LEEWAY_SECS); + let now = Utc::now().timestamp(); + + if matches!(basic.exp, Some(exp) if now - time_leeway > exp) { + return Err(JwtError::Expired); + } + + if now + time_leeway < basic.iat { + return Err(JwtError::NotYetValid); + } + + if !aud.contains(&basic.aud) { + return Err(JwtError::InvalidAudience); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use { + crate::{ + auth::AuthToken, + domain::ClientId, + jwt::{JwtBasicClaims, JwtError, VerifyableClaims, JWT_VALIDATION_TIME_LEEWAY_SECS}, + }, + ed25519_dalek::Keypair, + std::{collections::HashSet, time::Duration}, + }; + + #[derive(Debug)] + pub struct Jwt(pub String); + + impl Jwt { + pub fn decode(&self, aud: &HashSet) -> Result { + let claims = JwtBasicClaims::try_from_str(&self.0)?; + claims.verify_basic(aud, None)?; + Ok(claims.iss.into()) + } + } + + #[test] + fn token_validation() { + let aud = HashSet::from(["wss://relay.walletconnect.com".to_owned()]); + + // Invalid signature. + let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Ono2TWtvZEhad25lVlJTaHRhTGY4SktZa3hwREdwMXZHWm5wR21kQnBYOE0yZXh4SCIsInN1YiI6ImM0NzlmZTVkYzQ2NGU3NzFlNzhiMTkzZDIzOWE2NWI1OGQyNzhjYWQxYzM0YmZiMGI1NzE2ZTViYjUxNDkyOGUiLCJhdWQiOiJ3c3M6Ly9yZWxheS53YWxsZXRjb25uZWN0LmNvbSIsImlhdCI6MTY1NjkxMDA5NywiZXhwIjo0ODEyNjcwMDk3fQ.CLryc7bGZ_mBVh-P5p2tDDkjY8m9ji9xZXixJCbLLd4TMBh7F0EkChbWOOUQp4DyXUVK4CN-hxMZgt2xnePUBAx".to_owned()); + assert!(matches!(jwt.decode(&aud), Err(JwtError::Signature))); + + // Invalid multicodec header. + let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Ono2TWt2eDRWVnVCQlBIekVvTERiNWdOQzRyUW1uSnN0YzFib29oS2ZjSlV0OU12NjUiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.ixjxEISufsDpdsp4MRwD4Q100d8s7v4mSlIWIad6q8Nh__768pzPaCAVXQIZLxKPhuJQ92cZi7tVUJtAE1_UCg".to_owned()); + assert!(matches!(jwt.decode(&aud), Err(JwtError::Serialization(..)))); + + // Invalid multicodec base. + let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Onh6Nk1rb2RIWnduZVZSU2h0YUxmOEpLWWt4cERHcDF2R1pucEdtZEJwWDhNMmV4eEgiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.BINvB6JpUyp5Zs7qbIYMv7KybptioYFZP89ZFTMtvdGvEnRpYg70uzwSLdhZB1EPJZIrUMhybfT7Q1DYEqHwDw".to_owned()); + assert!(matches!(jwt.decode(&aud), Err(JwtError::Serialization(..)))); + + // Invalid DID prefix. + let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ4ZGlkOmtleTp6Nk1rb2RIWnduZVZSU2h0YUxmOEpLWWt4cERHcDF2R1pucEdtZEJwWDhNMmV4eEgiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.GGhlhz7kXCqCTUsn390O_hA9YQDa61d_DDiSVLsa70xrgFrGmjjoWWl1dsZn3RVq4V1IB0P1__NDJ2PK0OMiDA".to_owned()); + assert!(matches!(jwt.decode(&aud), Err(JwtError::Serialization(..)))); + + // Invalid DID method + let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6eGtleTp6Nk1rb2RIWnduZVZSU2h0YUxmOEpLWWt4cERHcDF2R1pucEdtZEJwWDhNMmV4eEgiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.rogEwjJLQFwbDm4psUty7MPkHrCrNiXxpwEYZ2nctppmF7MYvC3g7URZNYkKxMbFtNZ1hFCwsr1peEu3pVeJCg".to_owned()); + assert!(matches!(jwt.decode(&aud), Err(JwtError::Serialization(..)))); + + // Invalid issuer base58. + let jwt = Jwt("eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkaWQ6a2V5Ono2TWtvZEhad25lVlJTaHRhTGY4SktZa3hwREdwMXZHWm5wR21kQnBYOE0yZXh4SGwiLCJzdWIiOiJjNDc5ZmU1ZGM0NjRlNzcxZTc4YjE5M2QyMzlhNjViNThkMjc4Y2FkMWMzNGJmYjBiNTcxNmU1YmI1MTQ5MjhlIiwiYXVkIjoid3NzOi8vcmVsYXkud2FsbGV0Y29ubmVjdC5jb20iLCJpYXQiOjE2NTY5MTAwOTcsImV4cCI6NDgxMjY3MDA5N30.nLdxz4f6yJ8HsWZJUvpSHjFjoat4PfJav-kyqdHj6JXcX5SyDvp3QNB9doyzRWb9jpbA36Av0qn4kqLl-pGuBg".to_owned()); + assert!(matches!(jwt.decode(&aud), Err(JwtError::Serialization(..)))); + + let keypair = Keypair::generate(&mut rand::thread_rng()); + let sub: String = "test".to_owned(); + + // IAT in future. + let jwt = AuthToken::new(sub.clone()) + .iat(chrono::Utc::now() + chrono::Duration::hours(1)) + .as_jwt(&keypair) + .unwrap(); + assert!(matches!( + Jwt(jwt.into()).decode(&aud), + Err(JwtError::NotYetValid) + )); + + // IAT leeway, valid. + let jwt = AuthToken::new(sub.clone()) + .iat(chrono::Utc::now() + chrono::Duration::seconds(JWT_VALIDATION_TIME_LEEWAY_SECS)) + .as_jwt(&keypair) + .unwrap(); + assert!(Jwt(jwt.into()).decode(&aud).is_ok()); + + // IAT leeway, invalid. + let jwt = AuthToken::new(sub.clone()) + .iat( + chrono::Utc::now() + chrono::Duration::seconds(JWT_VALIDATION_TIME_LEEWAY_SECS + 1), + ) + .as_jwt(&keypair) + .unwrap(); + assert!(matches!( + Jwt(jwt.into()).decode(&aud), + Err(JwtError::NotYetValid) + )); + + // Past expiration. + let jwt = AuthToken::new(sub.clone()) + .iat(chrono::Utc::now() - chrono::Duration::hours(2)) + .ttl(Duration::from_secs(3600)) + .as_jwt(&keypair) + .unwrap(); + assert!(matches!( + Jwt(jwt.into()).decode(&aud), + Err(JwtError::Expired) + )); + + // Expiration leeway, valid. + let jwt = AuthToken::new(sub.clone()) + .iat( + chrono::Utc::now() + - chrono::Duration::seconds(3600 + JWT_VALIDATION_TIME_LEEWAY_SECS), + ) + .ttl(Duration::from_secs(3600)) + .as_jwt(&keypair) + .unwrap(); + assert!(Jwt(jwt.into()).decode(&aud).is_ok()); + + // Expiration leeway, invalid. + let jwt = AuthToken::new(sub.clone()) + .iat( + chrono::Utc::now() + - chrono::Duration::seconds(3600 + JWT_VALIDATION_TIME_LEEWAY_SECS + 1), + ) + .ttl(Duration::from_secs(3600)) + .as_jwt(&keypair) + .unwrap(); + assert!(matches!( + Jwt(jwt.into()).decode(&aud), + Err(JwtError::Expired) + )); + + // Invalid aud. + let jwt = AuthToken::new(sub) + .aud("wss://not.relay.walletconnect.com") + .as_jwt(&keypair) + .unwrap(); + assert!(matches!( + Jwt(jwt.into()).decode(&aud), + Err(JwtError::InvalidAudience) + )); + } +} diff --git a/relay_rpc/src/lib.rs b/relay_rpc/src/lib.rs index b2be167..5010829 100644 --- a/relay_rpc/src/lib.rs +++ b/relay_rpc/src/lib.rs @@ -3,6 +3,8 @@ pub mod auth; pub mod domain; +pub mod jwt; pub mod macros; pub mod rpc; +pub mod serde_helpers; pub mod user_agent; diff --git a/relay_rpc/src/rpc.rs b/relay_rpc/src/rpc.rs index f327ba4..05cff32 100644 --- a/relay_rpc/src/rpc.rs +++ b/relay_rpc/src/rpc.rs @@ -1,14 +1,19 @@ //! The crate exports common types used when interacting with messages between //! clients. This also includes communication over HTTP between relays. +pub use watch::*; use { - crate::domain::{DecodingError, MessageId, SubscriptionId, Topic}, + crate::{ + domain::{DecodingError, DidKey, MessageId, SubscriptionId, Topic}, + jwt::JwtError, + }, serde::{de::DeserializeOwned, Deserialize, Serialize}, std::{fmt::Debug, sync::Arc}, }; #[cfg(test)] mod tests; +pub mod watch; /// Version of the WalletConnect protocol that we're implementing. pub const JSON_RPC_VERSION_STR: &str = "2.0"; @@ -26,6 +31,11 @@ pub const MAX_SUBSCRIPTION_BATCH_SIZE: usize = 500; /// See pub const MAX_FETCH_BATCH_SIZE: usize = 500; +/// The maximum number of receipts allowed for a batch receive request. +/// +/// See +pub const MAX_RECEIVE_BATCH_SIZE: usize = 500; + type BoxError = Box; /// Errors covering payload validation problems. @@ -494,6 +504,56 @@ impl RequestPayload for BatchFetchMessages { } } +/// Represents a message receipt. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct Receipt { + /// The topic of the message to acknowledge. + pub topic: Topic, + + /// The ID of the message to acknowledge. + pub message_id: MessageId, +} + +/// Data structure representing publish request params. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct BatchReceiveMessages { + /// The receipts to acknowledge. + pub receipts: Vec, +} + +impl RequestPayload for BatchReceiveMessages { + type Error = GenericError; + type Response = bool; + + fn validate(&self) -> Result<(), ValidationError> { + let batch_size = self.receipts.len(); + + if batch_size == 0 { + return Err(ValidationError::BatchEmpty); + } + + if batch_size > MAX_RECEIVE_BATCH_SIZE { + return Err(ValidationError::BatchLimitExceeded { + limit: MAX_RECEIVE_BATCH_SIZE, + actual: batch_size, + }); + } + + for receipt in &self.receipts { + receipt + .topic + .decode() + .map_err(ValidationError::TopicDecoding)?; + } + + Ok(()) + } + + fn into_params(self) -> Params { + Params::BatchReceiveMessages(self) + } +} + /// Data structure representing publish request params. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Publish { @@ -519,8 +579,25 @@ pub struct Publish { } impl Publish { - /// Creates a subscription payload for these publish params. + /// Converts these publish params into subscription params. pub fn as_subscription( + &self, + subscription_id: SubscriptionId, + published_at: i64, + ) -> Subscription { + Subscription { + id: subscription_id, + data: SubscriptionData { + topic: self.topic.clone(), + message: self.message.clone(), + published_at, + tag: self.tag, + }, + } + } + + /// Creates a subscription request from these publish params. + pub fn as_subscription_request( &self, message_id: MessageId, subscription_id: SubscriptionId, @@ -529,15 +606,7 @@ impl Publish { Request { id: message_id, jsonrpc: JSON_RPC_VERSION.clone(), - params: Params::Subscription(Subscription { - id: subscription_id, - data: SubscriptionData { - topic: self.topic.clone(), - message: self.message.clone(), - published_at, - tag: self.tag, - }, - }), + params: Params::Subscription(self.as_subscription(subscription_id, published_at)), } } } @@ -556,7 +625,7 @@ pub enum PublishError { impl From for GenericError { fn from(err: PublishError) -> Self { - GenericError::Request(Box::new(err)) + Self::Request(Box::new(err)) } } @@ -584,6 +653,73 @@ where *x == Default::default() } +#[derive(Debug, thiserror::Error)] +pub enum WatchError { + #[error("Invalid TTL")] + InvalidTtl, + + #[error("Service URL is invalid or too long")] + InvalidServiceUrl, + + #[error("Webhook URL is invalid or too long")] + InvalidWebhookUrl, + + #[error("Failed to decode JWT: {0}")] + Jwt(#[from] JwtError), + + #[error("{0}")] + Other(BoxError), +} + +impl From for GenericError { + fn from(err: WatchError) -> Self { + Self::Request(Box::new(err)) + } +} + +/// Data structure representing watch registration request params. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WatchRegister { + /// JWT with [`watch::WatchRegisterClaims`] payload. + pub register_auth: String, +} + +impl RequestPayload for WatchRegister { + type Error = WatchError; + /// The Relay's public key. + type Response = DidKey; + + fn validate(&self) -> Result<(), ValidationError> { + Ok(()) + } + + fn into_params(self) -> Params { + Params::WatchRegister(self) + } +} + +/// Data structure representing watch unregistration request params. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WatchUnregister { + /// JWT with [`watch::WatchUnregisterClaims`] payload. + pub unregister_auth: String, +} + +impl RequestPayload for WatchUnregister { + type Error = WatchError; + type Response = bool; + + fn validate(&self) -> Result<(), ValidationError> { + Ok(()) + } + + fn into_params(self) -> Params { + Params::WatchUnregister(self) + } +} + /// Data structure representing subscription request params. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Subscription { @@ -670,6 +806,18 @@ pub enum Params { #[serde(rename = "irn_publish", alias = "iridium_publish")] Publish(Publish), + /// Parameters to batch receive. + #[serde(rename = "irn_batchReceive", alias = "iridium_batchReceive")] + BatchReceiveMessages(BatchReceiveMessages), + + /// Parameters to watch register. + #[serde(rename = "irn_watchRegister", alias = "iridium_watchRegister")] + WatchRegister(WatchRegister), + + /// Parameters to watch unregister. + #[serde(rename = "irn_watchUnregister", alias = "iridium_watchUnregister")] + WatchUnregister(WatchUnregister), + /// Parameters for a subscription. The messages for any given topic sent to /// clients are wrapped into this format. A `publish` message to a topic /// results in a `subscription` message to each client subscribed to the @@ -720,6 +868,9 @@ impl Request { Params::BatchUnsubscribe(params) => params.validate(), Params::BatchFetchMessages(params) => params.validate(), Params::Publish(params) => params.validate(), + Params::BatchReceiveMessages(params) => params.validate(), + Params::WatchRegister(params) => params.validate(), + Params::WatchUnregister(params) => params.validate(), Params::Subscription(params) => params.validate(), } } diff --git a/relay_rpc/src/rpc/tests.rs b/relay_rpc/src/rpc/tests.rs index 861ce25..6fc6a01 100644 --- a/relay_rpc/src/rpc/tests.rs +++ b/relay_rpc/src/rpc/tests.rs @@ -112,6 +112,72 @@ fn subscription() { assert_eq!(&payload, &deserialized) } +#[test] +fn batch_receive() { + let payload: Payload = Payload::Request(Request::new( + 1.into(), + Params::BatchReceiveMessages(BatchReceiveMessages { + receipts: vec![Receipt { + topic: Topic::from( + "c4163cf65859106b3f5435fc296e7765411178ed452d1c30337a6230138c9840", + ), + message_id: MessageId::new(123), + }], + }), + )); + + let serialized = serde_json::to_string(&payload).unwrap(); + eprintln!("{serialized}"); + + assert_eq!( + &serialized, + r#"{"id":1,"jsonrpc":"2.0","method":"irn_batchReceive","params":{"receipts":[{"topic":"c4163cf65859106b3f5435fc296e7765411178ed452d1c30337a6230138c9840","message_id":123}]}}"# + ); + + let deserialized: Payload = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(&payload, &deserialized) +} + +#[test] +fn watch_register() { + let params: WatchRegister = WatchRegister { + register_auth: "jwt".to_owned(), + }; + let payload: Payload = Payload::Request(Request::new(1.into(), Params::WatchRegister(params))); + + let serialized = serde_json::to_string(&payload).unwrap(); + + assert_eq!( + &serialized, + r#"{"id":1,"jsonrpc":"2.0","method":"irn_watchRegister","params":{"registerAuth":"jwt"}}"# + ); + + let deserialized: Payload = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(&payload, &deserialized) +} + +#[test] +fn watch_unregister() { + let params: WatchUnregister = WatchUnregister { + unregister_auth: "jwt".to_owned(), + }; + let payload: Payload = + Payload::Request(Request::new(1.into(), Params::WatchUnregister(params))); + + let serialized = serde_json::to_string(&payload).unwrap(); + + assert_eq!( + &serialized, + r#"{"id":1,"jsonrpc":"2.0","method":"irn_watchUnregister","params":{"unregisterAuth":"jwt"}}"# + ); + + let deserialized: Payload = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(&payload, &deserialized) +} + #[test] fn deserialize_iridium_method() { let serialized = r#"{"id":1,"jsonrpc":"2.0","method":"iridium_subscription","params":{"id":"test_id","data":{"topic":"test_topic","message":"test_message","publishedAt":123,"tag":1000}}}"#; @@ -536,7 +602,7 @@ fn validation() { // Batch fetch: invalid topic. let request = Request { id, - jsonrpc, + jsonrpc: jsonrpc.clone(), params: Params::BatchFetchMessages(BatchFetchMessages { topics: vec![Topic::from( "c4163cf65859106b3f5435fc296e7765411178ed452d1c30337a6230138c98401", @@ -547,4 +613,63 @@ fn validation() { request.validate(), Err(ValidationError::TopicDecoding(DecodingError::Length)) ); + + // Batch receive: valid. + let request = Request { + id, + jsonrpc: jsonrpc.clone(), + params: Params::BatchReceiveMessages(BatchReceiveMessages { + receipts: vec![Receipt { + topic: Topic::generate(), + message_id: MessageId::new(1), + }], + }), + }; + assert_eq!(request.validate(), Ok(())); + + // Batch receive: empty list. + let request = Request { + id, + jsonrpc: jsonrpc.clone(), + params: Params::BatchReceiveMessages(BatchReceiveMessages { receipts: vec![] }), + }; + assert_eq!(request.validate(), Err(ValidationError::BatchEmpty)); + + // Batch receive: too many items. + let receipts = (0..MAX_RECEIVE_BATCH_SIZE + 1) + .map(|_| Receipt { + topic: Topic::generate(), + message_id: MessageId::new(1), + }) + .collect(); + let request = Request { + id, + jsonrpc: jsonrpc.clone(), + params: Params::BatchReceiveMessages(BatchReceiveMessages { receipts }), + }; + assert_eq!( + request.validate(), + Err(ValidationError::BatchLimitExceeded { + limit: MAX_RECEIVE_BATCH_SIZE, + actual: MAX_RECEIVE_BATCH_SIZE + 1 + }) + ); + + // Batch receive: invalid topic. + let request = Request { + id, + jsonrpc, + params: Params::BatchReceiveMessages(BatchReceiveMessages { + receipts: vec![Receipt { + topic: Topic::from( + "c4163cf65859106b3f5435fc296e7765411178ed452d1c30337a6230138c98401", + ), + message_id: MessageId::new(1), + }], + }), + }; + assert_eq!( + request.validate(), + Err(ValidationError::TopicDecoding(DecodingError::Length)) + ); } diff --git a/relay_rpc/src/rpc/watch.rs b/relay_rpc/src/rpc/watch.rs new file mode 100644 index 0000000..2920364 --- /dev/null +++ b/relay_rpc/src/rpc/watch.rs @@ -0,0 +1,237 @@ +use { + crate::{ + domain::Topic, + jwt::{JwtBasicClaims, VerifyableClaims}, + }, + serde::{Deserialize, Serialize}, + std::sync::Arc, +}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum WatchType { + Subscriber, + Publisher, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum WatchStatus { + Accepted, + Queued, + Delivered, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum WatchAction { + #[serde(rename = "irn_watchRegister")] + Register, + #[serde(rename = "irn_watchUnregister")] + Unregister, + #[serde(rename = "irn_watchEvent")] + WatchEvent, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct WatchRegisterClaims { + /// Basic JWT claims. + #[serde(flatten)] + pub basic: JwtBasicClaims, + /// Action. Must be `irn_watchRegister`. + pub act: WatchAction, + /// Watcher type. Either subscriber or publisher. + pub typ: WatchType, + /// Webhook URL. + pub whu: String, + /// Array of message tags to watch. + pub tag: Vec, + /// Array of statuses to watch. + pub sts: Vec, +} + +impl VerifyableClaims for WatchRegisterClaims { + fn basic(&self) -> &JwtBasicClaims { + &self.basic + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct WatchUnregisterClaims { + /// Basic JWT claims. + #[serde(flatten)] + pub basic: JwtBasicClaims, + /// Action. Must be `irn_watchUnregister`. + pub act: WatchAction, + /// Watcher type. Either subscriber or publisher. + pub typ: WatchType, + /// Webhook URL. + pub whu: String, +} + +impl VerifyableClaims for WatchUnregisterClaims { + fn basic(&self) -> &JwtBasicClaims { + &self.basic + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WatchEventPayload { + /// Webhook status. Either `accepted`, `queued` or `delivered`. + pub status: WatchStatus, + /// Topic of the message that triggered the watch event. + pub topic: Topic, + /// The published message. + pub message: Arc, + /// Message publishing timestamp. + pub published_at: i64, + /// Message tag. + pub tag: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct WatchEventClaims { + /// Basic JWT claims. + #[serde(flatten)] + pub basic: JwtBasicClaims, + /// Action. Must be `irn_watchEvent`. + pub act: WatchAction, + /// Watcher type. Either subscriber or publisher. + pub typ: WatchType, + /// Webhook URL. + pub whu: String, + /// Event payload. + pub evt: WatchEventPayload, +} + +impl VerifyableClaims for WatchEventClaims { + fn basic(&self) -> &JwtBasicClaims { + &self.basic + } +} + +#[cfg(test)] +mod test { + use { + super::*, + crate::{auth::RELAY_WEBSOCKET_ADDRESS, domain::DecodedClientId}, + chrono::DateTime, + ed25519_dalek::Keypair, + }; + + const KEYPAIR: [u8; 64] = [ + 215, 142, 127, 216, 153, 183, 205, 110, 103, 118, 181, 195, 60, 71, 5, 221, 100, 196, 207, + 81, 229, 11, 116, 121, 235, 104, 1, 121, 25, 18, 218, 83, 216, 230, 100, 248, 132, 110, 55, + 65, 221, 87, 66, 160, 36, 95, 116, 86, 169, 49, 107, 17, 13, 50, 22, 147, 199, 109, 125, + 155, 89, 190, 186, 171, + ]; + + #[test] + fn watch_register_jwt() { + let key = Keypair::from_bytes(&KEYPAIR).unwrap(); + let iat = DateTime::parse_from_rfc3339("2000-01-01T00:00:00Z").unwrap(); + let exp = DateTime::parse_from_rfc3339("3000-01-01T00:00:00Z").unwrap(); + + let claims = WatchRegisterClaims { + basic: JwtBasicClaims { + iss: DecodedClientId::from_key(&key.public_key()).into(), + aud: RELAY_WEBSOCKET_ADDRESS.to_owned(), + sub: "https://example.com".to_owned(), + iat: iat.timestamp(), + exp: Some(exp.timestamp()), + }, + act: WatchAction::Register, + typ: WatchType::Subscriber, + whu: "https://example.com".to_owned(), + tag: vec![1100], + sts: vec![WatchStatus::Accepted], + }; + + // Verify that the fields are flattened, and that enums are serialized in + // lowercase. + assert_eq!( + serde_json::to_string(&claims).unwrap(), + r#"{"iss":"did:key:z6Mku3wsRZTAHjr6xrYWVUfyGeNSNz1GJRVfazp3N76AL9gE","aud":"wss://relay.walletconnect.com","sub":"https://example.com","iat":946684800,"exp":32503680000,"act":"irn_watchRegister","typ":"subscriber","whu":"https://example.com","tag":[1100],"sts":["accepted"]}"# + ); + + // Verify that the claims can be encoded and decoded correctly. + assert_eq!( + claims, + WatchRegisterClaims::try_from_str(&claims.encode(&key).unwrap()).unwrap() + ); + } + + #[test] + fn watch_unregister_jwt() { + let key = Keypair::from_bytes(&KEYPAIR).unwrap(); + let iat = DateTime::parse_from_rfc3339("2000-01-01T00:00:00Z").unwrap(); + let exp = DateTime::parse_from_rfc3339("3000-01-01T00:00:00Z").unwrap(); + + let claims = WatchUnregisterClaims { + basic: JwtBasicClaims { + iss: DecodedClientId::from_key(&key.public_key()).into(), + aud: RELAY_WEBSOCKET_ADDRESS.to_owned(), + sub: "https://example.com".to_owned(), + iat: iat.timestamp(), + exp: Some(exp.timestamp()), + }, + act: WatchAction::Unregister, + typ: WatchType::Publisher, + whu: "https://example.com".to_owned(), + }; + + // Verify that the fields are flattened, and that enums are serialized in + // lowercase. + assert_eq!( + serde_json::to_string(&claims).unwrap(), + r#"{"iss":"did:key:z6Mku3wsRZTAHjr6xrYWVUfyGeNSNz1GJRVfazp3N76AL9gE","aud":"wss://relay.walletconnect.com","sub":"https://example.com","iat":946684800,"exp":32503680000,"act":"irn_watchUnregister","typ":"publisher","whu":"https://example.com"}"# + ); + + // Verify that the claims can be encoded and decoded correctly. + assert_eq!( + claims, + WatchUnregisterClaims::try_from_str(&claims.encode(&key).unwrap()).unwrap() + ); + } + + #[test] + fn watch_event_jwt() { + let key = Keypair::from_bytes(&KEYPAIR).unwrap(); + let iat = DateTime::parse_from_rfc3339("2000-01-01T00:00:00Z").unwrap(); + let exp = DateTime::parse_from_rfc3339("3000-01-01T00:00:00Z").unwrap(); + let topic = Topic::from("474e88153f4db893de42c35e1891dc0e37a02e11961385de0475460fb48b8639"); + + let claims = WatchEventClaims { + basic: JwtBasicClaims { + iss: DecodedClientId::from_key(&key.public_key()).into(), + aud: RELAY_WEBSOCKET_ADDRESS.to_owned(), + sub: "https://example.com".to_owned(), + iat: iat.timestamp(), + exp: Some(exp.timestamp()), + }, + act: WatchAction::WatchEvent, + whu: "https://example.com".to_owned(), + typ: WatchType::Subscriber, + evt: WatchEventPayload { + status: WatchStatus::Accepted, + topic, + message: Arc::from("test message"), + published_at: iat.timestamp(), + tag: 1100, + }, + }; + + // Verify that the fields are flattened, and that enums are serialized in + // lowercase. + assert_eq!( + serde_json::to_string(&claims).unwrap(), + r#"{"iss":"did:key:z6Mku3wsRZTAHjr6xrYWVUfyGeNSNz1GJRVfazp3N76AL9gE","aud":"wss://relay.walletconnect.com","sub":"https://example.com","iat":946684800,"exp":32503680000,"act":"irn_watchEvent","typ":"subscriber","whu":"https://example.com","evt":{"status":"accepted","topic":"474e88153f4db893de42c35e1891dc0e37a02e11961385de0475460fb48b8639","message":"test message","publishedAt":946684800,"tag":1100}}"# + ); + + // Verify that the claims can be encoded and decoded correctly. + assert_eq!( + claims, + WatchEventClaims::try_from_str(&claims.encode(&key).unwrap()).unwrap() + ); + } +} diff --git a/relay_rpc/src/serde_helpers.rs b/relay_rpc/src/serde_helpers.rs new file mode 100644 index 0000000..0fc16b5 --- /dev/null +++ b/relay_rpc/src/serde_helpers.rs @@ -0,0 +1,56 @@ +pub mod client_id_as_did_key { + use { + crate::domain::DecodedClientId, + serde::{Deserialize, Deserializer, Serialize, Serializer}, + }; + + pub fn serialize(data: &DecodedClientId, serializer: S) -> Result + where + S: Serializer, + { + data.to_did_key().serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + + DecodedClientId::try_from_did_key(&String::deserialize(deserializer)?) + .map_err(D::Error::custom) + } +} + +#[cfg(test)] +mod test { + use { + crate::domain::{ClientId, DecodedClientId}, + serde::{Deserialize, Serialize}, + }; + + #[test] + fn client_id_as_did_key() { + #[derive(Serialize, Deserialize)] + struct Data { + #[serde(with = "super::client_id_as_did_key")] + client_id: DecodedClientId, + } + + let client_id = ClientId::new("z6MkhaXgBZDvotDkL5257faiztiGiC2QtKLGpbnnEGta2doK".into()); + + let serialized = serde_json::to_string(&Data { + client_id: client_id.decode().unwrap(), + }) + .unwrap(); + + assert_eq!( + serialized, + r#"{"client_id":"did:key:z6MkhaXgBZDvotDkL5257faiztiGiC2QtKLGpbnnEGta2doK"}"# + ); + + let deserialized: Data = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.client_id, client_id.decode().unwrap(),); + } +}