From 8a948288979a12069aab4b20b58cb09b326ce96e Mon Sep 17 00:00:00 2001 From: Emil Ernerfeldt Date: Mon, 26 Feb 2024 12:52:28 +0100 Subject: [PATCH] Add `Options` for controlling max frame size of incoming messages --- ewebsock/src/lib.rs | 41 +++++++++++++++---- ewebsock/src/native_tungstenite.rs | 50 +++++++++++++++--------- ewebsock/src/native_tungstenite_tokio.rs | 29 ++++++++++---- ewebsock/src/tungstenite_common.rs | 16 ++++++++ ewebsock/src/web.rs | 12 ++++-- example_app/src/app.rs | 2 +- 6 files changed, 110 insertions(+), 40 deletions(-) create mode 100644 ewebsock/src/tungstenite_common.rs diff --git a/ewebsock/src/lib.rs b/ewebsock/src/lib.rs index 1c76ff5..5917b8a 100644 --- a/ewebsock/src/lib.rs +++ b/ewebsock/src/lib.rs @@ -2,7 +2,8 @@ //! //! Usage: //! ``` no_run -//! let (mut sender, receiver) = ewebsock::connect("ws://example.com").unwrap(); +//! let options = ewebsock::Options::default(); +//! let (mut sender, receiver) = ewebsock::connect("ws://example.com", options).unwrap(); //! sender.send(ewebsock::WsMessage::Text("Hello!".into())); //! while let Some(event) = receiver.try_recv() { //! println!("Received {:?}", event); @@ -31,6 +32,9 @@ mod native_tungstenite_tokio; #[cfg(feature = "tokio")] pub use native_tungstenite_tokio::*; +#[cfg(not(target_arch = "wasm32"))] +mod tungstenite_common; + #[cfg(target_arch = "wasm32")] mod web; @@ -117,6 +121,26 @@ pub type Result = std::result::Result; pub(crate) type EventHandler = Box std::ops::ControlFlow<()>>; +/// Options for a connection. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Options { + /// The maximum size of a single incoming message frame, in bytes. + /// + /// The primary reason for setting this to something other than [`usize::MAX`] is + /// to prevent a malicious server from eating up all your RAM. + /// + /// Ignored on Web. + pub max_incoming_frame_size: usize, +} + +impl Default for Options { + fn default() -> Self { + Self { + max_incoming_frame_size: 64 * 1024 * 1024, + } + } +} + /// Connect to the given URL, and return a sender and receiver. /// /// This is a wrapper around [`ws_connect`]. @@ -127,9 +151,9 @@ pub(crate) type EventHandler = Box std::ops::ControlFl /// /// See also the [`connect_with_wakeup`] function, /// and the more advanced [`ws_connect`]. -pub fn connect(url: impl Into) -> Result<(WsSender, WsReceiver)> { +pub fn connect(url: impl Into, options: Options) -> Result<(WsSender, WsReceiver)> { let (ws_receiver, on_event) = WsReceiver::new(); - let ws_sender = ws_connect(url.into(), on_event)?; + let ws_sender = ws_connect(url.into(), options, on_event)?; Ok((ws_sender, ws_receiver)) } @@ -146,10 +170,11 @@ pub fn connect(url: impl Into) -> Result<(WsSender, WsReceiver)> { /// Note that you have to wait for [`WsEvent::Opened`] before sending messages. pub fn connect_with_wakeup( url: impl Into, + options: Options, wake_up: impl Fn() + Send + Sync + 'static, ) -> Result<(WsSender, WsReceiver)> { let (receiver, on_event) = WsReceiver::new_with_callback(wake_up); - let sender = ws_connect(url.into(), on_event)?; + let sender = ws_connect(url.into(), options, on_event)?; Ok((sender, receiver)) } @@ -160,8 +185,8 @@ pub fn connect_with_wakeup( /// # Errors /// * On native: failure to spawn a thread. /// * On web: failure to use `WebSocket` API. -pub fn ws_connect(url: String, on_event: EventHandler) -> Result { - ws_connect_impl(url, on_event) +pub fn ws_connect(url: String, options: Options, on_event: EventHandler) -> Result { + ws_connect_impl(url, options, on_event) } /// Connect and call the given event handler on each received event. @@ -174,6 +199,6 @@ pub fn ws_connect(url: String, on_event: EventHandler) -> Result { /// # Errors /// * On native: failure to spawn receiver thread. /// * On web: failure to use `WebSocket` API. -pub fn ws_receive(url: String, on_event: EventHandler) -> Result<()> { - ws_receive_impl(url, on_event) +pub fn ws_receive(url: String, options: Options, on_event: EventHandler) -> Result<()> { + ws_receive_impl(url, options, on_event) } diff --git a/ewebsock/src/native_tungstenite.rs b/ewebsock/src/native_tungstenite.rs index 55f5f9c..66821f2 100644 --- a/ewebsock/src/native_tungstenite.rs +++ b/ewebsock/src/native_tungstenite.rs @@ -2,7 +2,7 @@ use std::sync::mpsc::{Receiver, TryRecvError}; -use crate::{EventHandler, Result, WsEvent, WsMessage}; +use crate::{EventHandler, Options, Result, WsEvent, WsMessage}; /// This is how you send [`WsMessage`]s to the server. /// @@ -47,11 +47,11 @@ impl WsSender { } } -pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> { +pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHandler) -> Result<()> { std::thread::Builder::new() .name("ewebsock".to_owned()) .spawn(move || { - if let Err(err) = ws_receiver_blocking(&url, &on_event) { + if let Err(err) = ws_receiver_blocking(&url, options, &on_event) { on_event(WsEvent::Error(err)); } else { log::debug!("WebSocket connection closed."); @@ -64,17 +64,21 @@ pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> /// Connect and call the given event handler on each received event. /// -/// Blocking version of [`ws_receive`], only avilable on native. +/// Blocking version of [`ws_receive`], only available on native. /// /// # Errors /// All errors are returned to the caller, and NOT reported via `on_event`. -pub fn ws_receiver_blocking(url: &str, on_event: &EventHandler) -> Result<()> { - let (mut socket, response) = match tungstenite::connect(url) { - Ok(result) => result, - Err(err) => { - return Err(format!("Connect: {err}")); - } - }; +pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler) -> Result<()> { + let config = tungstenite::protocol::WebSocketConfig::from(options); + let max_redirects = 3; // tungstenite default + + let (mut socket, response) = + match tungstenite::client::connect_with_config(url, Some(config), max_redirects) { + Ok(result) => result, + Err(err) => { + return Err(format!("Connect: {err}")); + } + }; log::debug!("WebSocket HTTP response code: {}", response.status()); log::trace!( @@ -115,13 +119,17 @@ pub fn ws_receiver_blocking(url: &str, on_event: &EventHandler) -> Result<()> { } } -pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result { +pub(crate) fn ws_connect_impl( + url: String, + options: Options, + on_event: EventHandler, +) -> Result { let (tx, rx) = std::sync::mpsc::channel(); std::thread::Builder::new() .name("ewebsock".to_owned()) .spawn(move || { - if let Err(err) = ws_connect_blocking(&url, &on_event, &rx) { + if let Err(err) = ws_connect_blocking(&url, options, &on_event, &rx) { on_event(WsEvent::Error(err)); } else { log::debug!("WebSocket connection closed."); @@ -140,15 +148,19 @@ pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result, ) -> Result<()> { - let (mut socket, response) = match tungstenite::connect(url) { - Ok(result) => result, - Err(err) => { - return Err(format!("Connect: {err}")); - } - }; + let config = tungstenite::protocol::WebSocketConfig::from(options); + let max_redirects = 3; // tungstenite default + let (mut socket, response) = + match tungstenite::client::connect_with_config(url, Some(config), max_redirects) { + Ok(result) => result, + Err(err) => { + return Err(format!("Connect: {err}")); + } + }; log::debug!("WebSocket HTTP response code: {}", response.status()); log::trace!( diff --git a/ewebsock/src/native_tungstenite_tokio.rs b/ewebsock/src/native_tungstenite_tokio.rs index c929e72..88386e0 100644 --- a/ewebsock/src/native_tungstenite_tokio.rs +++ b/ewebsock/src/native_tungstenite_tokio.rs @@ -1,4 +1,4 @@ -use crate::{EventHandler, Result, WsEvent, WsMessage}; +use crate::{EventHandler, Options, Result, WsEvent, WsMessage}; /// This is how you send [`WsMessage`]s to the server. /// @@ -45,12 +45,21 @@ impl WsSender { async fn ws_connect_async( url: String, + options: Options, outgoing_messages_stream: impl futures::Stream, on_event: EventHandler, ) { use futures::StreamExt as _; - let (ws_stream, _) = match tokio_tungstenite::connect_async(url).await { + let config = tungstenite::protocol::WebSocketConfig::from(options); + let disable_nagle = false; // God damn everyone who adds negations to the names of their variables + let (ws_stream, _) = match tokio_tungstenite::connect_async_with_config( + url, + Some(config), + disable_nagle, + ) + .await + { Ok(result) => result, Err(err) => { on_event(WsEvent::Error(err.to_string())); @@ -106,12 +115,16 @@ async fn ws_connect_async( } #[allow(clippy::unnecessary_wraps)] -pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result { - Ok(ws_connect_native(url, on_event)) +pub(crate) fn ws_connect_impl( + url: String, + options: Options, + on_event: EventHandler, +) -> Result { + Ok(ws_connect_native(url, options, on_event)) } /// Like [`ws_connect`], but cannot fail. Only available on native builds. -fn ws_connect_native(url: String, on_event: EventHandler) -> WsSender { +fn ws_connect_native(url: String, options: Options, on_event: EventHandler) -> WsSender { let (tx, mut rx) = tokio::sync::mpsc::channel(1000); let outgoing_messages_stream = async_stream::stream! { @@ -122,12 +135,12 @@ fn ws_connect_native(url: String, on_event: EventHandler) -> WsSender { }; tokio::spawn(async move { - ws_connect_async(url.clone(), outgoing_messages_stream, on_event).await; + ws_connect_async(url.clone(), options, outgoing_messages_stream, on_event).await; log::debug!("WS connection finished."); }); WsSender { tx: Some(tx) } } -pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> { - ws_connect_impl(url, on_event).map(|sender| sender.forget()) +pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHandler) -> Result<()> { + ws_connect_impl(url, options, on_event).map(|sender| sender.forget()) } diff --git a/ewebsock/src/tungstenite_common.rs b/ewebsock/src/tungstenite_common.rs new file mode 100644 index 0000000..465122a --- /dev/null +++ b/ewebsock/src/tungstenite_common.rs @@ -0,0 +1,16 @@ +impl From for tungstenite::protocol::WebSocketConfig { + fn from(options: crate::Options) -> Self { + let crate::Options { + max_incoming_frame_size, + } = options; + + tungstenite::protocol::WebSocketConfig { + max_frame_size: if max_incoming_frame_size == usize::MAX { + None + } else { + Some(max_incoming_frame_size) + }, + ..Default::default() + } + } +} diff --git a/ewebsock/src/web.rs b/ewebsock/src/web.rs index 6f8e08a..f9dacfb 100644 --- a/ewebsock/src/web.rs +++ b/ewebsock/src/web.rs @@ -1,4 +1,4 @@ -use crate::{EventHandler, Result, WsEvent, WsMessage}; +use crate::{EventHandler, Options, Result, WsEvent, WsMessage}; #[allow(clippy::needless_pass_by_value)] fn string_from_js_value(s: wasm_bindgen::JsValue) -> String { @@ -63,11 +63,15 @@ impl WsSender { } } -pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> { - ws_connect_impl(url, on_event).map(|sender| sender.forget()) +pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHandler) -> Result<()> { + ws_connect_impl(url, options, on_event).map(|sender| sender.forget()) } -pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result { +pub(crate) fn ws_connect_impl( + url: String, + _ignored_options: Options, + on_event: EventHandler, +) -> Result { // Based on https://rustwasm.github.io/wasm-bindgen/examples/websockets.html use wasm_bindgen::closure::Closure; diff --git a/example_app/src/app.rs b/example_app/src/app.rs index 2a53d60..6ce5926 100644 --- a/example_app/src/app.rs +++ b/example_app/src/app.rs @@ -61,7 +61,7 @@ impl eframe::App for ExampleApp { impl ExampleApp { fn connect(&mut self, ctx: egui::Context) { let wakeup = move || ctx.request_repaint(); // wake up UI thread on new message - match ewebsock::connect_with_wakeup(&self.url, wakeup) { + match ewebsock::connect_with_wakeup(&self.url, Default::default(), wakeup) { Ok((ws_sender, ws_receiver)) => { self.frontend = Some(FrontEnd::new(ws_sender, ws_receiver)); self.error.clear();