diff --git a/Cargo.toml b/Cargo.toml index 134428a..bcf95ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,8 @@ url = "2.3" warp = { version = "0.3", default-features = false } serde_json = "1.0" rand = "0.8.5" +futures-util = "0.3" +once_cell = "1.19" [[example]] name = "websocket_client" diff --git a/examples/websocket_client.rs b/examples/websocket_client.rs index 562006b..a49e88b 100644 --- a/examples/websocket_client.rs +++ b/examples/websocket_client.rs @@ -1,6 +1,6 @@ use { relay_client::{ - error::Error, + error::ClientError, websocket::{Client, CloseFrame, ConnectionHandler, PublishedMessage}, ConnectionOptions, }, @@ -49,11 +49,11 @@ impl ConnectionHandler for Handler { ); } - fn inbound_error(&mut self, error: Error) { + fn inbound_error(&mut self, error: ClientError) { println!("[{}] inbound error: {error}", self.name); } - fn outbound_error(&mut self, error: Error) { + fn outbound_error(&mut self, error: ClientError) { println!("[{}] outbound error: {error}", self.name); } } diff --git a/relay_client/src/error.rs b/relay_client/src/error.rs index a76e984..eeb79ce 100644 --- a/relay_client/src/error.rs +++ b/relay_client/src/error.rs @@ -1,3 +1,5 @@ +use relay_rpc::rpc::{self, error::ServiceError}; + pub type BoxError = Box; /// Errors generated while parsing @@ -23,7 +25,7 @@ pub enum RequestBuildError { /// Possible Relay client errors. #[derive(Debug, thiserror::Error)] -pub enum Error { +pub enum ClientError { #[error("Failed to build connection request: {0}")] RequestBuilder(#[from] RequestBuildError), @@ -42,15 +44,69 @@ pub enum Error { #[error("Invalid response ID")] InvalidResponseId, + #[error("Invalid error response")] + InvalidErrorResponse, + #[error("Serialization failed: {0}")] Serialization(serde_json::Error), #[error("Deserialization failed: {0}")] Deserialization(serde_json::Error), - #[error("RPC error ({code}): {message}")] - Rpc { code: i32, message: String }, + #[error("RPC error: code={code} data={data:?} message={message}")] + Rpc { + code: i32, + message: String, + data: Option, + }, #[error("Invalid request type")] InvalidRequestType, } + +impl From for ClientError { + fn from(err: rpc::ErrorData) -> Self { + Self::Rpc { + code: err.code, + message: err.message, + data: err.data, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Client errors encountered while performing the request. + #[error(transparent)] + Client(ClientError), + + /// Error response received from the relay. + #[error(transparent)] + Response(#[from] rpc::Error), +} + +impl From for Error { + fn from(err: ClientError) -> Self { + match err { + ClientError::Rpc { + code, + message, + data, + } => { + let err = rpc::ErrorData { + code, + message, + data, + }; + + match rpc::Error::try_from(err) { + Ok(err) => Error::Response(err), + + Err(_) => Error::Client(ClientError::InvalidErrorResponse), + } + } + + _ => Error::Client(err), + } + } +} diff --git a/relay_client/src/http.rs b/relay_client/src/http.rs index 26ec4d2..a615559 100644 --- a/relay_client/src/http.rs +++ b/relay_client/src/http.rs @@ -1,6 +1,6 @@ use { crate::{ - error::{BoxError, Error}, + error::{BoxError, ClientError, Error}, ConnectionOptions, MessageIdGenerator, }, @@ -9,15 +9,15 @@ use { auth::ed25519_dalek::SigningKey, domain::{DecodedClientId, SubscriptionId, Topic}, jwt::{self, JwtError, VerifyableClaims}, - rpc::{self, Receipt, RequestPayload}, + rpc::{self, Receipt, ServiceRequest}, }, std::{sync::Arc, time::Duration}, url::Url, }; pub type TransportError = reqwest::Error; -pub type Response = Result<::Response, Error>; -pub type EmptyResponse = Result<(), Error>; +pub type Response = Result<::Response, Error<::Error>>; +pub type EmptyResponse = Result<(), Error<::Error>>; #[derive(Debug, thiserror::Error)] pub enum RequestParamsError { @@ -41,9 +41,6 @@ pub enum HttpClientError { #[error("JWT error: {0}")] Jwt(#[from] JwtError), - - #[error("RPC error: code={} message={}", .0.code, .0.message)] - RpcError(rpc::ErrorData), } #[derive(Debug, Clone)] @@ -82,7 +79,7 @@ pub struct Client { } impl Client { - pub fn new(opts: &ConnectionOptions) -> Result { + pub fn new(opts: &ConnectionOptions) -> Result { let mut headers = HeaderMap::new(); opts.update_request_headers(&mut headers)?; @@ -111,11 +108,14 @@ impl Client { tag: u32, ttl: Duration, prompt: bool, - ) -> EmptyResponse { + ) -> EmptyResponse { let ttl_secs = ttl .as_secs() .try_into() - .map_err(|_| HttpClientError::InvalidRequest(RequestParamsError::InvalidTtl.into()))?; + .map_err(|_| { + HttpClientError::InvalidRequest(RequestParamsError::InvalidTtl.into()).into() + }) + .map_err(Error::Client)?; self.request(rpc::Publish { topic, @@ -175,7 +175,8 @@ impl Client { .ttl .as_secs() .try_into() - .map_err(|err| HttpClientError::InvalidRequest(Box::new(err)))?; + .map_err(|err| HttpClientError::InvalidRequest(Box::new(err)).into()) + .map_err(Error::Client)?; let exp = iat + ttl_sec; let claims = rpc::WatchRegisterClaims { @@ -194,7 +195,11 @@ impl Client { }; let payload = rpc::WatchRegister { - register_auth: claims.encode(keypair).map_err(HttpClientError::Jwt)?, + register_auth: claims + .encode(keypair) + .map_err(HttpClientError::Jwt) + .map_err(ClientError::from) + .map_err(Error::Client)?, }; self.request(payload).await @@ -230,7 +235,11 @@ impl Client { }; let payload = rpc::WatchUnregister { - unregister_auth: claims.encode(keypair).map_err(HttpClientError::Jwt)?, + unregister_auth: claims + .encode(keypair) + .map_err(HttpClientError::Jwt) + .map_err(ClientError::from) + .map_err(Error::Client)?, }; self.request(payload).await @@ -299,7 +308,7 @@ impl Client { pub(crate) async fn request(&self, payload: T) -> Response where - T: RequestPayload, + T: ServiceRequest, { let payload = rpc::Payload::Request(rpc::Request { id: self.id_generator.next(), @@ -307,37 +316,42 @@ impl Client { params: payload.into_params(), }); - let result = self - .client - .post(self.url.clone()) - .json(&payload) - .send() - .await - .map_err(HttpClientError::Transport)?; + let response = async { + let result = self + .client + .post(self.url.clone()) + .json(&payload) + .send() + .await + .map_err(HttpClientError::Transport)?; - let status = result.status(); + let status = result.status(); - if !status.is_success() { - let body = result.text().await; - return Err(HttpClientError::InvalidHttpCode(status, body).into()); - } + if !status.is_success() { + let body = result.text().await; + return Err(HttpClientError::InvalidHttpCode(status, body)); + } - let response = result - .json::() - .await - .map_err(|_| HttpClientError::InvalidResponse)?; + result + .json::() + .await + .map_err(|_| HttpClientError::InvalidResponse) + } + .await + .map_err(ClientError::from) + .map_err(Error::Client)?; match response { rpc::Payload::Response(rpc::Response::Success(response)) => { serde_json::from_value(response.result) - .map_err(|_| HttpClientError::InvalidResponse.into()) + .map_err(|_| Error::Client(HttpClientError::InvalidResponse.into())) } rpc::Payload::Response(rpc::Response::Error(response)) => { - Err(HttpClientError::RpcError(response.error).into()) + Err(ClientError::from(response.error).into()) } - _ => Err(HttpClientError::InvalidResponse.into()), + _ => Err(Error::Client(HttpClientError::InvalidResponse.into())), } } } diff --git a/relay_client/src/lib.rs b/relay_client/src/lib.rs index 6fccffd..fe6f8e9 100644 --- a/relay_client/src/lib.rs +++ b/relay_client/src/lib.rs @@ -1,5 +1,5 @@ use { - crate::error::{Error, RequestBuildError}, + crate::error::{ClientError, RequestBuildError}, ::http::HeaderMap, relay_rpc::{ auth::{SerializedAuthToken, RELAY_WEBSOCKET_ADDRESS}, diff --git a/relay_client/src/websocket.rs b/relay_client/src/websocket.rs index c3f71c4..bde3eda 100644 --- a/relay_client/src/websocket.rs +++ b/relay_client/src/websocket.rs @@ -1,6 +1,6 @@ use { self::connection::{connection_event_loop, ConnectionControl}, - crate::{error::Error, ConnectionOptions}, + crate::{error::ClientError, ConnectionOptions}, relay_rpc::{ domain::{MessageId, SubscriptionId, Topic}, rpc::{ @@ -114,11 +114,11 @@ pub trait ConnectionHandler: Send + 'static { /// Called when an inbound error occurs, such as data deserialization /// failure, or an unknown response message ID. - fn inbound_error(&mut self, _error: Error) {} + fn inbound_error(&mut self, _error: ClientError) {} /// Called when an outbound error occurs, i.e. failed to write to the /// websocket stream. - fn outbound_error(&mut self, _error: Error) {} + fn outbound_error(&mut self, _error: ClientError) {} } /// The Relay WebSocket RPC client. @@ -291,7 +291,7 @@ impl Client { } /// Opens a connection to the Relay. - pub async fn connect(&self, opts: &ConnectionOptions) -> Result<(), Error> { + pub async fn connect(&self, opts: &ConnectionOptions) -> Result<(), ClientError> { let (tx, rx) = oneshot::channel(); let request = opts.as_ws_request()?; @@ -300,14 +300,14 @@ impl Client { .send(ConnectionControl::Connect { request, tx }) .is_ok() { - rx.await.map_err(|_| Error::ChannelClosed)? + rx.await.map_err(|_| ClientError::ChannelClosed)? } else { - Err(Error::ChannelClosed) + Err(ClientError::ChannelClosed) } } /// Closes the Relay connection. - pub async fn disconnect(&self) -> Result<(), Error> { + pub async fn disconnect(&self) -> Result<(), ClientError> { let (tx, rx) = oneshot::channel(); if self @@ -315,9 +315,9 @@ impl Client { .send(ConnectionControl::Disconnect { tx }) .is_ok() { - rx.await.map_err(|_| Error::ChannelClosed)? + rx.await.map_err(|_| ClientError::ChannelClosed)? } else { - Err(Error::ChannelClosed) + Err(ClientError::ChannelClosed) } } @@ -330,7 +330,7 @@ impl Client { unreachable!(); }; - request.tx.send(Err(Error::ChannelClosed)).ok(); + request.tx.send(Err(ClientError::ChannelClosed)).ok(); } } } diff --git a/relay_client/src/websocket/connection.rs b/relay_client/src/websocket/connection.rs index 22a2b3d..e08fb2e 100644 --- a/relay_client/src/websocket/connection.rs +++ b/relay_client/src/websocket/connection.rs @@ -8,7 +8,7 @@ use { }, crate::{ websocket::{stream::StreamEvent, PublishedMessage}, - Error, + ClientError, HttpRequest, }, futures_util::{stream::FusedStream, Stream, StreamExt}, @@ -22,11 +22,11 @@ use { pub(super) enum ConnectionControl { Connect { request: HttpRequest<()>, - tx: oneshot::Sender>, + tx: oneshot::Sender>, }, Disconnect { - tx: oneshot::Sender>, + tx: oneshot::Sender>, }, OutboundRequest(OutboundRequest), @@ -107,7 +107,7 @@ impl Connection { Self { stream: None } } - async fn connect(&mut self, request: HttpRequest<()>) -> Result<(), Error> { + async fn connect(&mut self, request: HttpRequest<()>) -> Result<(), ClientError> { if let Some(mut stream) = self.stream.take() { stream.close(None).await?; } @@ -117,7 +117,7 @@ impl Connection { Ok(()) } - async fn disconnect(&mut self) -> Result<(), Error> { + async fn disconnect(&mut self) -> Result<(), ClientError> { let stream = self.stream.take(); match stream { diff --git a/relay_client/src/websocket/fetch.rs b/relay_client/src/websocket/fetch.rs index 6030455..08e2df8 100644 --- a/relay_client/src/websocket/fetch.rs +++ b/relay_client/src/websocket/fetch.rs @@ -1,10 +1,10 @@ use { super::{create_request, Client, ResponseFuture}, - crate::Error, + crate::error::Error, futures_util::{FutureExt, Stream}, relay_rpc::{ domain::Topic, - rpc::{BatchFetchMessages, SubscriptionData}, + rpc::{BatchFetchMessages, ServiceRequest, SubscriptionData}, }, std::{ pin::Pin, @@ -48,7 +48,7 @@ impl FetchMessageStream { } impl Stream for FetchMessageStream { - type Item = Result; + type Item = Result::Error>>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { diff --git a/relay_client/src/websocket/inbound.rs b/relay_client/src/websocket/inbound.rs index f5d28bf..2581c7d 100644 --- a/relay_client/src/websocket/inbound.rs +++ b/relay_client/src/websocket/inbound.rs @@ -1,8 +1,8 @@ use { - crate::Error, + crate::ClientError, relay_rpc::{ domain::MessageId, - rpc::{ErrorResponse, Payload, RequestPayload, Response, SuccessfulResponse}, + rpc::{self, ErrorResponse, Payload, Response, ServiceRequest, SuccessfulResponse}, }, tokio::sync::mpsc::UnboundedSender, tokio_tungstenite::tungstenite::Message, @@ -24,7 +24,7 @@ pub struct InboundRequest { impl InboundRequest where - T: RequestPayload, + T: ServiceRequest, { pub(super) fn new(id: MessageId, data: T, tx: UnboundedSender) -> Self { Self { id, tx, data } @@ -45,20 +45,23 @@ where /// /// Returns an error if the response can't be serialized, or if the /// underlying channel is closed. - pub fn respond(self, response: Result) -> Result<(), Error> { + pub fn respond(self, response: Result) -> Result<(), ClientError> { let response = match response { Ok(data) => Response::Success(SuccessfulResponse::new( self.id, - serde_json::to_value(data).map_err(Error::Serialization)?, + serde_json::to_value(data).map_err(ClientError::Serialization)?, )), - Err(err) => Response::Error(ErrorResponse::new(self.id, err.into())), + Err(err) => Response::Error(ErrorResponse::new(self.id, rpc::Error::Handler(err))), }; let message = Message::Text( - serde_json::to_string(&Payload::Response(response)).map_err(Error::Serialization)?, + serde_json::to_string(&Payload::Response(response)) + .map_err(ClientError::Serialization)?, ); - self.tx.send(message).map_err(|_| Error::ChannelClosed) + self.tx + .send(message) + .map_err(|_| ClientError::ChannelClosed) } } diff --git a/relay_client/src/websocket/outbound.rs b/relay_client/src/websocket/outbound.rs index dbbc244..6a927c6 100644 --- a/relay_client/src/websocket/outbound.rs +++ b/relay_client/src/websocket/outbound.rs @@ -1,7 +1,7 @@ use { - crate::Error, + crate::{error::Error, ClientError}, pin_project::pin_project, - relay_rpc::rpc::{Params, RequestPayload}, + relay_rpc::rpc::{Params, ServiceRequest}, std::{ future::Future, marker::PhantomData, @@ -16,13 +16,13 @@ use { #[derive(Debug)] pub struct OutboundRequest { pub(super) params: Params, - pub(super) tx: oneshot::Sender>, + pub(super) tx: oneshot::Sender>, } impl OutboundRequest { pub(super) fn new( params: Params, - tx: oneshot::Sender>, + tx: oneshot::Sender>, ) -> Self { Self { params, tx } } @@ -33,12 +33,12 @@ impl OutboundRequest { #[pin_project] pub struct ResponseFuture { #[pin] - rx: oneshot::Receiver>, + rx: oneshot::Receiver>, _marker: PhantomData, } impl ResponseFuture { - pub(super) fn new(rx: oneshot::Receiver>) -> Self { + pub(super) fn new(rx: oneshot::Receiver>) -> Self { Self { rx, _marker: PhantomData, @@ -48,22 +48,22 @@ impl ResponseFuture { impl Future for ResponseFuture where - T: RequestPayload, + T: ServiceRequest, { - type Output = Result; + type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let result = ready!(this.rx.poll(cx)).map_err(|_| Error::ChannelClosed)?; + let result = ready!(this.rx.poll(cx)).map_err(|_| ClientError::ChannelClosed)?; let result = match result { - Ok(value) => serde_json::from_value(value).map_err(Error::Deserialization), + Ok(value) => serde_json::from_value(value).map_err(ClientError::Deserialization), Err(err) => Err(err), }; - Poll::Ready(result) + Poll::Ready(result.map_err(Into::into)) } } @@ -84,9 +84,9 @@ impl EmptyResponseFuture { impl Future for EmptyResponseFuture where - T: RequestPayload, + T: ServiceRequest, { - type Output = Result<(), Error>; + type Output = Result<(), Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Poll::Ready(ready!(self.project().rx.poll(cx)).map(|_| ())) @@ -98,7 +98,7 @@ where /// [`ClientStream`][crate::client::ClientStream]. pub fn create_request(data: T) -> (OutboundRequest, ResponseFuture) where - T: RequestPayload, + T: ServiceRequest, { let (tx, rx) = oneshot::channel(); diff --git a/relay_client/src/websocket/stream.rs b/relay_client/src/websocket/stream.rs index 3601b0f..28a5e1e 100644 --- a/relay_client/src/websocket/stream.rs +++ b/relay_client/src/websocket/stream.rs @@ -6,11 +6,11 @@ use { TransportError, WebsocketClientError, }, - crate::{error::Error, HttpRequest, MessageIdGenerator}, + crate::{error::ClientError, HttpRequest, MessageIdGenerator}, futures_util::{stream::FusedStream, SinkExt, Stream, StreamExt}, relay_rpc::{ domain::MessageId, - rpc::{Params, Payload, Request, RequestPayload, Response, Subscription}, + rpc::{self, Params, Payload, Response, ServiceRequest, Subscription}, }, std::{ collections::{hash_map::Entry, HashMap}, @@ -58,11 +58,11 @@ pub enum StreamEvent { /// Error generated when failed to parse an inbound message, invalid request /// type or message ID. - InboundError(Error), + InboundError(ClientError), /// Error generated when failed to write data to the underlying websocket /// stream. - OutboundError(Error), + OutboundError(ClientError), /// The websocket connection was closed. /// @@ -81,7 +81,7 @@ pub struct ClientStream { socket: SocketStream, outbound_tx: UnboundedSender, outbound_rx: UnboundedReceiver, - requests: HashMap>>, + requests: HashMap>>, id_generator: MessageIdGenerator, close_frame: Option>, } @@ -107,13 +107,13 @@ impl ClientStream { pub fn send_raw(&mut self, request: OutboundRequest) { let tx = request.tx; let id = self.id_generator.next(); - let request = Payload::Request(Request::new(id, request.params)); + let request = Payload::Request(rpc::Request::new(id, request.params)); let serialized = serde_json::to_string(&request); match serialized { Ok(data) => match self.requests.entry(id) { Entry::Occupied(_) => { - tx.send(Err(Error::DuplicateRequestId)).ok(); + tx.send(Err(ClientError::DuplicateRequestId)).ok(); } Entry::Vacant(entry) => { @@ -123,7 +123,7 @@ impl ClientStream { }, Err(err) => { - tx.send(Err(Error::Serialization(err))).ok(); + tx.send(Err(ClientError::Serialization(err))).ok(); } } } @@ -132,7 +132,7 @@ impl ClientStream { /// returning a future that resolves with the response. pub fn send(&mut self, request: T) -> ResponseFuture where - T: RequestPayload, + T: ServiceRequest, { let (request, response) = create_request(request); self.send_raw(request); @@ -140,7 +140,7 @@ impl ClientStream { } /// Closes the connection. - pub async fn close(&mut self, frame: Option>) -> Result<(), Error> { + pub async fn close(&mut self, frame: Option>) -> Result<(), ClientError> { self.close_frame = frame.clone(); self.socket .close(frame) @@ -156,7 +156,9 @@ impl ClientStream { Ok(payload) => payload, Err(err) => { - return Some(StreamEvent::InboundError(Error::Deserialization(err))) + return Some(StreamEvent::InboundError(ClientError::Deserialization( + err, + ))) } }; @@ -172,7 +174,7 @@ impl ClientStream { ) } - _ => StreamEvent::InboundError(Error::InvalidRequestType), + _ => StreamEvent::InboundError(ClientError::InvalidRequestType), }; Some(event) @@ -183,25 +185,21 @@ impl ClientStream { if id.is_zero() { return match response { - Response::Error(response) => { - Some(StreamEvent::InboundError(Error::Rpc { - code: response.error.code, - message: response.error.message, - })) - } + Response::Error(response) => Some(StreamEvent::InboundError( + ClientError::from(response.error), + )), - Response::Success(_) => { - Some(StreamEvent::InboundError(Error::InvalidResponseId)) - } + Response::Success(_) => Some(StreamEvent::InboundError( + ClientError::InvalidResponseId, + )), }; } if let Some(tx) = self.requests.remove(&id) { let result = match response { - Response::Error(response) => Err(Error::Rpc { - code: response.error.code, - message: response.error.message, - }), + Response::Error(response) => { + Err(ClientError::from(response.error)) + } Response::Success(response) => Ok(response.result), }; @@ -215,7 +213,7 @@ impl ClientStream { None } else { - Some(StreamEvent::InboundError(Error::InvalidResponseId)) + Some(StreamEvent::InboundError(ClientError::InvalidResponseId)) } } } diff --git a/relay_rpc/Cargo.toml b/relay_rpc/Cargo.toml index f25c10d..c8fa18b 100644 --- a/relay_rpc/Cargo.toml +++ b/relay_rpc/Cargo.toml @@ -52,6 +52,7 @@ alloy-json-rpc = { git = "https://github.com/alloy-rs/alloy.git", rev = "e6f98e1 alloy-json-abi = { version = "0.6.2", optional = true } alloy-sol-types = { version = "0.6.2", optional = true } alloy-primitives = { version = "0.6.2", optional = true } +strum = { version = "0.26", features = ["strum_macros", "derive"] } [dev-dependencies] tokio = { version = "1.35.1", features = ["test-util", "macros"] } diff --git a/relay_rpc/src/rpc.rs b/relay_rpc/src/rpc.rs index e50564d..b9c1c4d 100644 --- a/relay_rpc/src/rpc.rs +++ b/relay_rpc/src/rpc.rs @@ -1,16 +1,14 @@ //! The crate exports common types used when interacting with messages between //! clients. This also includes communication over HTTP between relays. -pub use watch::*; use { - crate::{ - domain::{DecodingError, DidKey, MessageId, SubscriptionId, Topic}, - jwt::JwtError, - }, + crate::domain::{DidKey, MessageId, SubscriptionId, Topic}, serde::{de::DeserializeOwned, Deserialize, Serialize}, std::{fmt::Debug, sync::Arc}, }; +pub use {error::*, watch::*}; +pub mod error; pub mod msg_id; #[cfg(test)] mod tests; @@ -37,96 +35,6 @@ pub const MAX_FETCH_BATCH_SIZE: usize = 500; /// See pub const MAX_RECEIVE_BATCH_SIZE: usize = 500; -type BoxError = Box; - -/// Errors covering payload validation problems. -#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] -pub enum ValidationError { - #[error("Topic decoding failed: {0}")] - TopicDecoding(DecodingError), - - #[error("Subscription ID decoding failed: {0}")] - SubscriptionIdDecoding(DecodingError), - - #[error("Invalid request ID")] - RequestId, - - #[error("Invalid JSON RPC version")] - JsonRpcVersion, - - #[error("The batch contains too many items ({actual}). Maximum number of items is {limit}")] - BatchLimitExceeded { limit: usize, actual: usize }, - - #[error("The batch contains no items")] - BatchEmpty, -} - -/// Errors caught while processing the request. These are meant to be serialized -/// into [`ErrorResponse`], and should be specific enough for the clients to -/// make sense of the problem. -#[derive(Debug, thiserror::Error)] -pub enum GenericError { - #[error("Authorization error: {0}")] - Authorization(BoxError), - - #[error("Too many requests")] - TooManyRequests, - - /// Request parameters validation failed. - #[error("Request validation error: {0}")] - Validation(#[from] ValidationError), - - /// Request/response serialization error. - #[error("Serialization failed: {0}")] - Serialization(#[from] serde_json::Error), - - /// An unsupported JSON RPC method. - #[error("Unsupported request method")] - RequestMethod, - - /// Generic request-specific error, which could not be caught by the request - /// validation. - #[error("Failed to process request: {0}")] - Request(BoxError), - - /// Internal server error. These are not request-specific, but should not - /// normally happen if the relay is fully operational. - #[error("Internal error: {0}")] - Other(BoxError), -} - -impl GenericError { - /// The error code. These are the standard JSONRPC error codes. The Relay - /// specific errors are in 3000-4999 range to align with the websocket close - /// codes. - pub fn code(&self) -> i32 { - match self { - Self::Authorization(_) => 3000, - Self::TooManyRequests => 3001, - Self::Serialization(_) => -32700, - Self::Validation(_) => -32602, - Self::RequestMethod => -32601, - Self::Request(_) => -32000, - Self::Other(_) => -32603, - } - } -} - -impl From for ErrorData -where - T: Into, -{ - fn from(value: T) -> Self { - let value = value.into(); - - ErrorData { - code: value.code(), - message: value.to_string(), - data: None, - } - } -} - pub trait Serializable: Debug + Clone + PartialEq + Eq + Serialize + DeserializeOwned + Send + Sync + 'static { @@ -138,15 +46,15 @@ impl Serializable for T where /// Trait that adds validation capabilities and strong typing to errors and /// successful responses. Implemented for all possible RPC request types. -pub trait RequestPayload: Serializable { +pub trait ServiceRequest: Serializable { /// The error representing a failed request. - type Error: Into + Send + 'static; + type Error: ServiceError; /// The type of a successful response. type Response: Serializable; /// Validates the request parameters. - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { Ok(()) } @@ -174,7 +82,7 @@ impl Payload { } } - pub fn validate(&self) -> Result<(), ValidationError> { + pub fn validate(&self) -> Result<(), PayloadError> { match self { Self::Request(request) => request.validate(), Self::Response(response) => response.validate(), @@ -211,7 +119,7 @@ impl Response { } /// Validates the response parameters. - pub fn validate(&self) -> Result<(), ValidationError> { + pub fn validate(&self) -> Result<(), PayloadError> { match self { Self::Success(response) => response.validate(), Self::Error(response) => response.validate(), @@ -243,9 +151,9 @@ impl SuccessfulResponse { } /// Validates the parameters. - pub fn validate(&self) -> Result<(), ValidationError> { + pub fn validate(&self) -> Result<(), PayloadError> { if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { - Err(ValidationError::JsonRpcVersion) + Err(PayloadError::InvalidJsonRpcVersion) } else { // We can't really validate `serde_json::Value` without knowing the expected // value type. @@ -269,18 +177,18 @@ pub struct ErrorResponse { impl ErrorResponse { /// Create a new instance. - pub fn new(id: MessageId, error: ErrorData) -> Self { + pub fn new(id: MessageId, error: impl Into) -> Self { Self { id, jsonrpc: JSON_RPC_VERSION.clone(), - error, + error: error.into(), } } /// Validates the parameters. - pub fn validate(&self) -> Result<(), ValidationError> { + pub fn validate(&self) -> Result<(), PayloadError> { if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { - Err(ValidationError::JsonRpcVersion) + Err(PayloadError::InvalidJsonRpcVersion) } else { Ok(()) } @@ -301,6 +209,12 @@ pub struct ErrorData { pub data: Option, } +#[derive(Debug, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq)] +pub enum SubscriptionError { + #[error("Subscriber limit exceeded")] + SubscriberLimitExceeded, +} + /// Data structure representing subscribe request params. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Subscribe { @@ -313,14 +227,14 @@ pub struct Subscribe { pub block: bool, } -impl RequestPayload for Subscribe { - type Error = GenericError; +impl ServiceRequest for Subscribe { + type Error = SubscriptionError; type Response = SubscriptionId; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(ValidationError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -341,14 +255,14 @@ pub struct Unsubscribe { pub subscription_id: SubscriptionId, } -impl RequestPayload for Unsubscribe { - type Error = GenericError; +impl ServiceRequest for Unsubscribe { + type Error = SubscriptionError; type Response = bool; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(ValidationError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; // FIXME: Subscription ID validation is currently disabled, since SDKs do not // use the actual IDs generated by the relay, and instead send some randomized @@ -374,14 +288,14 @@ pub struct FetchMessages { pub topic: Topic, } -impl RequestPayload for FetchMessages { +impl ServiceRequest for FetchMessages { type Error = GenericError; type Response = FetchResponse; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(ValidationError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -415,26 +329,23 @@ pub struct BatchSubscribe { pub block: bool, } -impl RequestPayload for BatchSubscribe { - type Error = GenericError; +impl ServiceRequest for BatchSubscribe { + type Error = SubscriptionError; type Response = Vec; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { let batch_size = self.topics.len(); if batch_size == 0 { - return Err(ValidationError::BatchEmpty); + return Err(PayloadError::BatchEmpty); } if batch_size > MAX_SUBSCRIPTION_BATCH_SIZE { - return Err(ValidationError::BatchLimitExceeded { - limit: MAX_SUBSCRIPTION_BATCH_SIZE, - actual: batch_size, - }); + return Err(PayloadError::BatchLimitExceeded); } for topic in &self.topics { - topic.decode().map_err(ValidationError::TopicDecoding)?; + topic.decode().map_err(|_| PayloadError::InvalidTopic)?; } Ok(()) @@ -452,22 +363,19 @@ pub struct BatchUnsubscribe { pub subscriptions: Vec, } -impl RequestPayload for BatchUnsubscribe { - type Error = GenericError; +impl ServiceRequest for BatchUnsubscribe { + type Error = SubscriptionError; type Response = bool; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { let batch_size = self.subscriptions.len(); if batch_size == 0 { - return Err(ValidationError::BatchEmpty); + return Err(PayloadError::BatchEmpty); } if batch_size > MAX_SUBSCRIPTION_BATCH_SIZE { - return Err(ValidationError::BatchLimitExceeded { - limit: MAX_SUBSCRIPTION_BATCH_SIZE, - actual: batch_size, - }); + return Err(PayloadError::BatchLimitExceeded); } for sub in &self.subscriptions { @@ -489,26 +397,23 @@ pub struct BatchFetchMessages { pub topics: Vec, } -impl RequestPayload for BatchFetchMessages { +impl ServiceRequest for BatchFetchMessages { type Error = GenericError; type Response = FetchResponse; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { let batch_size = self.topics.len(); if batch_size == 0 { - return Err(ValidationError::BatchEmpty); + return Err(PayloadError::BatchEmpty); } if batch_size > MAX_FETCH_BATCH_SIZE { - return Err(ValidationError::BatchLimitExceeded { - limit: MAX_FETCH_BATCH_SIZE, - actual: batch_size, - }); + return Err(PayloadError::BatchLimitExceeded); } for topic in &self.topics { - topic.decode().map_err(ValidationError::TopicDecoding)?; + topic.decode().map_err(|_| PayloadError::InvalidTopic)?; } Ok(()) @@ -536,29 +441,26 @@ pub struct BatchReceiveMessages { pub receipts: Vec, } -impl RequestPayload for BatchReceiveMessages { +impl ServiceRequest for BatchReceiveMessages { type Error = GenericError; type Response = bool; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { let batch_size = self.receipts.len(); if batch_size == 0 { - return Err(ValidationError::BatchEmpty); + return Err(PayloadError::BatchEmpty); } if batch_size > MAX_RECEIVE_BATCH_SIZE { - return Err(ValidationError::BatchLimitExceeded { - limit: MAX_RECEIVE_BATCH_SIZE, - actual: batch_size, - }); + return Err(PayloadError::BatchLimitExceeded); } for receipt in &self.receipts { receipt .topic .decode() - .map_err(ValidationError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; } Ok(()) @@ -626,7 +528,7 @@ impl Publish { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq)] pub enum PublishError { #[error("TTL too short")] TtlTooShort, @@ -634,24 +536,18 @@ pub enum PublishError { #[error("TTL too long")] TtlTooLong, - #[error("{0}")] - Other(BoxError), -} - -impl From for GenericError { - fn from(err: PublishError) -> Self { - Self::Request(Box::new(err)) - } + #[error("Mailbox limit exceeded")] + MailboxLimitExceeded, } -impl RequestPayload for Publish { +impl ServiceRequest for Publish { type Error = PublishError; type Response = bool; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(ValidationError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -668,7 +564,13 @@ where *x == Default::default() } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq)] +pub enum GenericError { + #[error("Unknown error")] + Unknown, +} + +#[derive(Debug, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq)] pub enum WatchError { #[error("Invalid TTL")] InvalidTtl, @@ -679,17 +581,11 @@ pub enum WatchError { #[error("Webhook URL is invalid or too long")] InvalidWebhookUrl, - #[error("Failed to decode JWT: {0}")] - Jwt(#[from] JwtError), + #[error("Invalid action")] + InvalidAction, - #[error("{0}")] - Other(BoxError), -} - -impl From for GenericError { - fn from(err: WatchError) -> Self { - Self::Request(Box::new(err)) - } + #[error("Invalid JWT")] + InvalidJwt, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -707,11 +603,11 @@ pub struct WatchRegister { pub register_auth: String, } -impl RequestPayload for WatchRegister { +impl ServiceRequest for WatchRegister { type Error = WatchError; type Response = WatchRegisterResponse; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { Ok(()) } @@ -728,11 +624,11 @@ pub struct WatchUnregister { pub unregister_auth: String, } -impl RequestPayload for WatchUnregister { +impl ServiceRequest for WatchUnregister { type Error = WatchError; type Response = bool; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { Ok(()) } @@ -751,19 +647,19 @@ pub struct Subscription { pub data: SubscriptionData, } -impl RequestPayload for Subscription { +impl ServiceRequest for Subscription { type Error = GenericError; type Response = bool; - fn validate(&self) -> Result<(), ValidationError> { + fn validate(&self) -> Result<(), PayloadError> { self.id .decode() - .map_err(ValidationError::SubscriptionIdDecoding)?; + .map_err(|_| PayloadError::InvalidSubscriptionId)?; self.data .topic .decode() - .map_err(ValidationError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -872,13 +768,13 @@ impl Request { } /// Validates the request payload. - pub fn validate(&self) -> Result<(), ValidationError> { + pub fn validate(&self) -> Result<(), PayloadError> { if !self.id.validate() { - return Err(ValidationError::RequestId); + return Err(PayloadError::InvalidRequestId); } if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { - return Err(ValidationError::JsonRpcVersion); + return Err(PayloadError::InvalidJsonRpcVersion); } match &self.params { diff --git a/relay_rpc/src/rpc/error.rs b/relay_rpc/src/rpc/error.rs new file mode 100644 index 0000000..d7f5c4c --- /dev/null +++ b/relay_rpc/src/rpc/error.rs @@ -0,0 +1,186 @@ +use { + super::ErrorData, + std::fmt::{Debug, Display}, +}; + +/// Provides serialization to and from string tags. This has a blanket +/// implementation for all error types that derive [`strum::EnumString`] and +/// [`strum::IntoStaticStr`]. +pub trait ServiceError: Sized + Debug + Display + PartialEq + Send + 'static { + fn from_tag(tag: &str) -> Result; + + fn tag(&self) -> &'static str; +} + +impl ServiceError for T +where + T: for<'a> TryFrom<&'a str> + Debug + Display + PartialEq + Send + 'static, + for<'a> &'static str: From<&'a T>, +{ + fn from_tag(tag: &str) -> Result { + tag.try_into().map_err(|_| InvalidErrorData) + } + + fn tag(&self) -> &'static str { + self.into() + } +} + +#[derive(Debug, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq)] +pub enum AuthError { + #[error("Project not found")] + ProjectNotFound, + + #[error("Project ID not specified")] + ProjectIdNotSpecified, + + #[error("Project inactive")] + ProjectInactive, + + #[error("Origin not allowed")] + OriginNotAllowed, + + #[error("Invalid JWT")] + InvalidJwt, + + #[error("Missing JWT")] + MissingJwt, + + #[error("Country blocked")] + CountryBlocked, +} + +/// Request payload validation problems. +#[derive( + Debug, Clone, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq, +)] +pub enum PayloadError { + #[error("Invalid request method")] + InvalidMethod, + + #[error("Invalid request parameters")] + InvalidParams, + + #[error("Payload size exceeded")] + PayloadSizeExceeded, + + #[error("Topic decoding failed")] + InvalidTopic, + + #[error("Subscription ID decoding failed")] + InvalidSubscriptionId, + + #[error("Invalid request ID")] + InvalidRequestId, + + #[error("Invalid JSON RPC version")] + InvalidJsonRpcVersion, + + #[error("The batch contains too many items")] + BatchLimitExceeded, + + #[error("The batch contains no items")] + BatchEmpty, + + #[error("Failed to deserialize request")] + Serialization, +} + +#[derive(Debug, thiserror::Error, strum::EnumString, strum::IntoStaticStr, PartialEq, Eq)] +pub enum InternalError { + #[error("Storage operation failed")] + StorageError, + + #[error("Failed to serialize response")] + Serialization, + + #[error("Internal error")] + Unknown, +} + +/// Errors caught while processing the request. These are meant to be serialized +/// into [`super::ErrorResponse`], and should be specific enough for the clients +/// to make sense of the problem. +#[derive(Debug, thiserror::Error, strum::IntoStaticStr, PartialEq, Eq)] +pub enum Error { + #[error("Auth error: {0}")] + Auth(#[from] AuthError), + + #[error("Invalid payload: {0}")] + Payload(#[from] PayloadError), + + #[error("Request handler error: {0}")] + Handler(T), + + #[error("Internal error: {0}")] + Internal(#[from] InternalError), + + #[error("Too many requests")] + TooManyRequests, +} + +impl Error { + pub fn code(&self) -> i32 { + match self { + Self::Auth(_) => CODE_AUTH, + Self::TooManyRequests => CODE_TOO_MANY_REQUESTS, + Self::Payload(_) => CODE_PAYLOAD, + Self::Handler(_) => CODE_HANDLER, + Self::Internal(_) => CODE_INTERNAL, + } + } + + pub fn tag(&self) -> &'static str { + match &self { + Self::Auth(err) => err.tag(), + Self::Payload(err) => err.tag(), + Self::Handler(err) => err.tag(), + Self::Internal(err) => err.tag(), + Self::TooManyRequests => self.into(), + } + } +} + +pub const CODE_AUTH: i32 = 3000; +pub const CODE_TOO_MANY_REQUESTS: i32 = 3001; +pub const CODE_PAYLOAD: i32 = -32600; +pub const CODE_HANDLER: i32 = -32000; +pub const CODE_INTERNAL: i32 = -32603; + +#[derive(Debug, thiserror::Error)] +#[error("Invalid error data")] +pub struct InvalidErrorData; + +impl TryFrom for Error { + type Error = InvalidErrorData; + + fn try_from(err: ErrorData) -> Result { + let tag = &err.data; + + let err = match err.code { + CODE_AUTH => Error::Auth(try_parse_error(tag)?), + CODE_TOO_MANY_REQUESTS => Error::TooManyRequests, + CODE_PAYLOAD => Error::Payload(try_parse_error(tag)?), + CODE_HANDLER => Error::Handler(try_parse_error(tag)?), + CODE_INTERNAL => Error::Internal(try_parse_error(tag)?), + _ => return Err(InvalidErrorData), + }; + + Ok(err) + } +} + +#[inline] +fn try_parse_error(tag: &Option) -> Result { + tag.as_deref().ok_or(InvalidErrorData).map(T::from_tag)? +} + +impl From> for ErrorData { + fn from(err: Error) -> Self { + Self { + code: err.code(), + message: err.to_string(), + data: Some(err.tag().to_owned()), + } + } +} diff --git a/relay_rpc/src/rpc/tests.rs b/relay_rpc/src/rpc/tests.rs index 074c492..0345b73 100644 --- a/relay_rpc/src/rpc/tests.rs +++ b/relay_rpc/src/rpc/tests.rs @@ -281,7 +281,7 @@ fn validation() { prompt: false, }), }; - assert_eq!(request.validate(), Err(ValidationError::RequestId)); + assert_eq!(request.validate(), Err(PayloadError::InvalidRequestId)); // Invalid JSONRPC version. let request = Request { @@ -295,7 +295,7 @@ fn validation() { prompt: false, }), }; - assert_eq!(request.validate(), Err(ValidationError::JsonRpcVersion)); + assert_eq!(request.validate(), Err(PayloadError::InvalidJsonRpcVersion)); // Publish: valid. let request = Request { @@ -323,10 +323,7 @@ fn validation() { prompt: false, }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Subscribe: valid. let request = Request { @@ -348,10 +345,7 @@ fn validation() { block: false, }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Unsubscribe: valid. let request = Request { @@ -373,10 +367,7 @@ fn validation() { subscription_id: subscription_id.clone(), }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Fetch: valid. let request = Request { @@ -396,10 +387,7 @@ fn validation() { topic: Topic::from("invalid"), }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Subscription: valid. let request = Request { @@ -431,12 +419,7 @@ fn validation() { }, }), }; - assert_eq!( - request.validate(), - Err(ValidationError::SubscriptionIdDecoding( - DecodingError::Length - )) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidSubscriptionId)); // Subscription: invalid topic. let request = Request { @@ -452,10 +435,7 @@ fn validation() { }, }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch subscription: valid. let request = Request { @@ -477,7 +457,7 @@ fn validation() { block: false, }), }; - assert_eq!(request.validate(), Err(ValidationError::BatchEmpty)); + assert_eq!(request.validate(), Err(PayloadError::BatchEmpty)); // Batch subscription: too many items. let topics = (0..MAX_SUBSCRIPTION_BATCH_SIZE + 1) @@ -491,13 +471,7 @@ fn validation() { block: false, }), }; - assert_eq!( - request.validate(), - Err(ValidationError::BatchLimitExceeded { - limit: MAX_SUBSCRIPTION_BATCH_SIZE, - actual: MAX_SUBSCRIPTION_BATCH_SIZE + 1 - }) - ); + assert_eq!(request.validate(), Err(PayloadError::BatchLimitExceeded)); // Batch subscription: invalid topic. let request = Request { @@ -510,10 +484,7 @@ fn validation() { block: false, }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch unsubscription: valid. let request = Request { @@ -536,7 +507,7 @@ fn validation() { subscriptions: vec![], }), }; - assert_eq!(request.validate(), Err(ValidationError::BatchEmpty)); + assert_eq!(request.validate(), Err(PayloadError::BatchEmpty)); // Batch unsubscription: too many items. let subscriptions = (0..MAX_SUBSCRIPTION_BATCH_SIZE + 1) @@ -550,13 +521,7 @@ fn validation() { jsonrpc: jsonrpc.clone(), params: Params::BatchUnsubscribe(BatchUnsubscribe { subscriptions }), }; - assert_eq!( - request.validate(), - Err(ValidationError::BatchLimitExceeded { - limit: MAX_SUBSCRIPTION_BATCH_SIZE, - actual: MAX_SUBSCRIPTION_BATCH_SIZE + 1 - }) - ); + assert_eq!(request.validate(), Err(PayloadError::BatchLimitExceeded)); // Batch unsubscription: invalid topic. let request = Request { @@ -571,10 +536,7 @@ fn validation() { }], }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch fetch: valid. let request = Request { @@ -592,7 +554,7 @@ fn validation() { jsonrpc: jsonrpc.clone(), params: Params::BatchFetchMessages(BatchFetchMessages { topics: vec![] }), }; - assert_eq!(request.validate(), Err(ValidationError::BatchEmpty)); + assert_eq!(request.validate(), Err(PayloadError::BatchEmpty)); // Batch fetch: too many items. let topics = (0..MAX_SUBSCRIPTION_BATCH_SIZE + 1) @@ -603,13 +565,7 @@ fn validation() { jsonrpc: jsonrpc.clone(), params: Params::BatchFetchMessages(BatchFetchMessages { topics }), }; - assert_eq!( - request.validate(), - Err(ValidationError::BatchLimitExceeded { - limit: MAX_SUBSCRIPTION_BATCH_SIZE, - actual: MAX_SUBSCRIPTION_BATCH_SIZE + 1 - }) - ); + assert_eq!(request.validate(), Err(PayloadError::BatchLimitExceeded)); // Batch fetch: invalid topic. let request = Request { @@ -621,10 +577,7 @@ fn validation() { )], }), }; - assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch receive: valid. let request = Request { @@ -645,7 +598,7 @@ fn validation() { jsonrpc: jsonrpc.clone(), params: Params::BatchReceiveMessages(BatchReceiveMessages { receipts: vec![] }), }; - assert_eq!(request.validate(), Err(ValidationError::BatchEmpty)); + assert_eq!(request.validate(), Err(PayloadError::BatchEmpty)); // Batch receive: too many items. let receipts = (0..MAX_RECEIVE_BATCH_SIZE + 1) @@ -659,13 +612,7 @@ fn validation() { jsonrpc: jsonrpc.clone(), params: Params::BatchReceiveMessages(BatchReceiveMessages { receipts }), }; - assert_eq!( - request.validate(), - Err(ValidationError::BatchLimitExceeded { - limit: MAX_RECEIVE_BATCH_SIZE, - actual: MAX_RECEIVE_BATCH_SIZE + 1 - }) - ); + assert_eq!(request.validate(), Err(PayloadError::BatchLimitExceeded)); // Batch receive: invalid topic. let request = Request { @@ -680,8 +627,71 @@ fn validation() { }], }), }; + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); +} + +#[test] +fn error_tags() { + // Validate hardcoded string tags, so that we don't accidentally break + // compatibility with other SDKs as a result of refactoring. + + assert_eq!( + Error::::TooManyRequests.tag(), + "TooManyRequests" + ); + + assert_eq!( + SubscriptionError::SubscriberLimitExceeded.tag(), + "SubscriberLimitExceeded" + ); + + assert_eq!(PublishError::TtlTooShort.tag(), "TtlTooShort"); + assert_eq!(PublishError::TtlTooLong.tag(), "TtlTooLong"); assert_eq!( - request.validate(), - Err(ValidationError::TopicDecoding(DecodingError::Length)) + PublishError::MailboxLimitExceeded.tag(), + "MailboxLimitExceeded" ); + + assert_eq!(GenericError::Unknown.tag(), "Unknown"); + + assert_eq!(WatchError::InvalidTtl.tag(), "InvalidTtl"); + assert_eq!(WatchError::InvalidServiceUrl.tag(), "InvalidServiceUrl"); + assert_eq!(WatchError::InvalidWebhookUrl.tag(), "InvalidWebhookUrl"); + assert_eq!(WatchError::InvalidAction.tag(), "InvalidAction"); + assert_eq!(WatchError::InvalidJwt.tag(), "InvalidJwt"); + + assert_eq!(AuthError::ProjectNotFound.tag(), "ProjectNotFound"); + assert_eq!( + AuthError::ProjectIdNotSpecified.tag(), + "ProjectIdNotSpecified" + ); + assert_eq!(AuthError::ProjectInactive.tag(), "ProjectInactive"); + assert_eq!(AuthError::OriginNotAllowed.tag(), "OriginNotAllowed"); + assert_eq!(AuthError::InvalidJwt.tag(), "InvalidJwt"); + assert_eq!(AuthError::MissingJwt.tag(), "MissingJwt"); + assert_eq!(AuthError::CountryBlocked.tag(), "CountryBlocked"); + + assert_eq!(PayloadError::InvalidMethod.tag(), "InvalidMethod"); + assert_eq!(PayloadError::InvalidParams.tag(), "InvalidParams"); + assert_eq!( + PayloadError::PayloadSizeExceeded.tag(), + "PayloadSizeExceeded" + ); + assert_eq!(PayloadError::InvalidTopic.tag(), "InvalidTopic"); + assert_eq!( + PayloadError::InvalidSubscriptionId.tag(), + "InvalidSubscriptionId" + ); + assert_eq!(PayloadError::InvalidRequestId.tag(), "InvalidRequestId"); + assert_eq!( + PayloadError::InvalidJsonRpcVersion.tag(), + "InvalidJsonRpcVersion" + ); + assert_eq!(PayloadError::BatchLimitExceeded.tag(), "BatchLimitExceeded"); + assert_eq!(PayloadError::BatchEmpty.tag(), "BatchEmpty"); + assert_eq!(PayloadError::Serialization.tag(), "Serialization"); + + assert_eq!(InternalError::StorageError.tag(), "StorageError"); + assert_eq!(InternalError::Serialization.tag(), "Serialization"); + assert_eq!(InternalError::Unknown.tag(), "Unknown"); }