From 934de81863e320286191a3f90d45c61d9646c0f1 Mon Sep 17 00:00:00 2001 From: jprochazk Date: Thu, 24 Oct 2024 13:21:47 +0200 Subject: [PATCH] use `set_read_timeout` instead of sleep --- ewebsock/src/lib.rs | 10 +- ewebsock/src/native_tungstenite.rs | 169 ++++++++++++++--------------- 2 files changed, 87 insertions(+), 92 deletions(-) diff --git a/ewebsock/src/lib.rs b/ewebsock/src/lib.rs index ed9f54c..98fd4df 100644 --- a/ewebsock/src/lib.rs +++ b/ewebsock/src/lib.rs @@ -147,8 +147,12 @@ pub struct Options { /// Currently only supported on native. pub subprotocols: Vec, - /// Delay blocking in ms - default 10ms - pub delay_blocking: std::time::Duration, + /// Socket read timeout. + /// + /// Reads will block forever if this is set to `None` or `Some(Duration::ZERO)`. + /// + /// Defaults to 10ms. + pub read_timeout: Option, } impl Default for Options { @@ -159,7 +163,7 @@ impl Default for Options { subprotocols: vec![], // let the OS schedule something else, otherwise busy-loop // TODO: use polling on native instead - delay_blocking: std::time::Duration::from_millis(0), + read_timeout: Some(std::time::Duration::from_millis(10)), } } } diff --git a/ewebsock/src/native_tungstenite.rs b/ewebsock/src/native_tungstenite.rs index 1505dea..3e0b2dc 100644 --- a/ewebsock/src/native_tungstenite.rs +++ b/ewebsock/src/native_tungstenite.rs @@ -1,10 +1,14 @@ //! Native implementation of the WebSocket client using the `tungstenite` crate. +use std::net::TcpStream; use std::{ ops::ControlFlow, sync::mpsc::{Receiver, TryRecvError}, }; +use tungstenite::stream::MaybeTlsStream; +use tungstenite::WebSocket; + use crate::tungstenite_common::into_requester; use crate::{EventHandler, Options, Result, WsEvent, WsMessage}; @@ -70,13 +74,13 @@ pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHand /// # Errors /// All errors are returned to the caller, and NOT reported via `on_event`. pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler) -> Result<()> { - let delay = options.delay_blocking; let uri: tungstenite::http::Uri = url .parse() .map_err(|err| format!("Failed to parse URL {url:?}: {err}"))?; let config = tungstenite::protocol::WebSocketConfig::from(options.clone()); let max_redirects = 3; // tungstenite default + let read_timeout = options.read_timeout; let (mut socket, response) = match tungstenite::client::connect_with_config( into_requester(uri, options), Some(config), @@ -88,6 +92,8 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler } }; + set_read_timeout(&mut socket, read_timeout)?; + log::debug!("WebSocket HTTP response code: {}", response.status()); log::trace!( "WebSocket response contains the following headers: {:?}", @@ -103,31 +109,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler } loop { - let control = match socket.read() { - Ok(incoming_msg) => match incoming_msg { - tungstenite::protocol::Message::Text(text) => { - on_event(WsEvent::Message(WsMessage::Text(text))) - } - tungstenite::protocol::Message::Binary(data) => { - on_event(WsEvent::Message(WsMessage::Binary(data))) - } - tungstenite::protocol::Message::Ping(data) => { - on_event(WsEvent::Message(WsMessage::Ping(data))) - } - tungstenite::protocol::Message::Pong(data) => { - on_event(WsEvent::Message(WsMessage::Pong(data))) - } - tungstenite::protocol::Message::Close(close) => { - on_event(WsEvent::Closed); - log::debug!("WebSocket close received: {close:?}"); - return Ok(()); - } - tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()), - }, - Err(err) => { - return Err(format!("read: {err}")); - } - }; + let control = read_from_socket(&mut socket, on_event)?; if control.is_break() { log::trace!("Closing connection due to Break"); @@ -136,12 +118,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler .map_err(|err| format!("Failed to close connection: {err}")); } - // without the check we wouldn't yield at all on some platforms - if delay == std::time::Duration::ZERO { - std::thread::yield_now(); - } else { - std::thread::sleep(delay); - } + std::thread::yield_now(); } } @@ -178,12 +155,13 @@ pub fn ws_connect_blocking( on_event: &EventHandler, rx: &Receiver, ) -> Result<()> { - let delay = options.delay_blocking; let config = tungstenite::protocol::WebSocketConfig::from(options.clone()); let max_redirects = 3; // tungstenite default let uri: tungstenite::http::Uri = url .parse() .map_err(|err| format!("Failed to parse URL {url:?}: {err}"))?; + + let read_timeout = options.read_timeout; let (mut socket, response) = match tungstenite::client::connect_with_config( into_requester(uri, options), Some(config), @@ -195,6 +173,8 @@ pub fn ws_connect_blocking( } }; + set_read_timeout(&mut socket, read_timeout)?; + log::debug!("WebSocket HTTP response code: {}", response.status()); log::trace!( "WebSocket response contains the following headers: {:?}", @@ -209,26 +189,9 @@ pub fn ws_connect_blocking( .map_err(|err| format!("Failed to close connection: {err}")); } - match socket.get_mut() { - tungstenite::stream::MaybeTlsStream::Plain(stream) => stream.set_nonblocking(true), - - // tungstenite::stream::MaybeTlsStream::NativeTls(stream) => { - // stream.get_mut().set_nonblocking(true) - // } - #[cfg(feature = "tls")] - tungstenite::stream::MaybeTlsStream::Rustls(stream) => { - stream.get_mut().set_nonblocking(true) - } - _ => return Err(format!("Unknown tungstenite stream {:?}", socket.get_mut())), - } - .map_err(|err| format!("Failed to make WebSocket non-blocking: {err}"))?; - loop { - let mut did_work = false; - match rx.try_recv() { Ok(outgoing_message) => { - did_work = true; let outgoing_message = match outgoing_message { WsMessage::Text(text) => tungstenite::protocol::Message::Text(text), WsMessage::Binary(data) => tungstenite::protocol::Message::Binary(data), @@ -251,39 +214,7 @@ pub fn ws_connect_blocking( Err(TryRecvError::Empty) => {} }; - let control = match socket.read() { - Ok(incoming_msg) => { - did_work = true; - match incoming_msg { - tungstenite::protocol::Message::Text(text) => { - on_event(WsEvent::Message(WsMessage::Text(text))) - } - tungstenite::protocol::Message::Binary(data) => { - on_event(WsEvent::Message(WsMessage::Binary(data))) - } - tungstenite::protocol::Message::Ping(data) => { - on_event(WsEvent::Message(WsMessage::Ping(data))) - } - tungstenite::protocol::Message::Pong(data) => { - on_event(WsEvent::Message(WsMessage::Pong(data))) - } - tungstenite::protocol::Message::Close(close) => { - on_event(WsEvent::Closed); - log::debug!("Close received: {close:?}"); - return Ok(()); - } - tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()), - } - } - Err(tungstenite::Error::Io(io_err)) - if io_err.kind() == std::io::ErrorKind::WouldBlock => - { - ControlFlow::Continue(()) // Ignore - } - Err(err) => { - return Err(format!("read: {err}")); - } - }; + let control = read_from_socket(&mut socket, on_event)?; if control.is_break() { log::trace!("Closing connection due to Break"); @@ -292,15 +223,75 @@ pub fn ws_connect_blocking( .map_err(|err| format!("Failed to close connection: {err}")); } - if !did_work { - // without the check we wouldn't yield at all on some platforms - if delay == std::time::Duration::ZERO { - std::thread::yield_now(); - } else { - std::thread::sleep(delay); + std::thread::yield_now(); + } +} + +fn read_from_socket( + socket: &mut WebSocket>, + on_event: &EventHandler, +) -> Result> { + let control = match socket.read() { + Ok(incoming_msg) => match incoming_msg { + tungstenite::protocol::Message::Text(text) => { + on_event(WsEvent::Message(WsMessage::Text(text))) + } + tungstenite::protocol::Message::Binary(data) => { + on_event(WsEvent::Message(WsMessage::Binary(data))) } + tungstenite::protocol::Message::Ping(data) => { + on_event(WsEvent::Message(WsMessage::Ping(data))) + } + tungstenite::protocol::Message::Pong(data) => { + on_event(WsEvent::Message(WsMessage::Pong(data))) + } + tungstenite::protocol::Message::Close(close) => { + on_event(WsEvent::Closed); + log::debug!("WebSocket close received: {close:?}"); + ControlFlow::Break(()) + } + tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()), + }, + // If we get `WouldBlock`, then the read timed out. + // Windows may emit `TimedOut` instead. + Err(tungstenite::Error::Io(io_err)) + if io_err.kind() == std::io::ErrorKind::WouldBlock + || io_err.kind() == std::io::ErrorKind::TimedOut => + { + ControlFlow::Continue(()) // Ignore } + Err(err) => { + return Err(format!("read: {err}")); + } + }; + + Ok(control) +} + +fn set_read_timeout( + s: &mut WebSocket>, + value: Option, +) -> Result<()> { + // zero timeout is the same as no timeout + if value.is_none() || value.is_some_and(|value| value.is_zero()) { + return Ok(()); } + + match s.get_mut() { + MaybeTlsStream::Plain(s) => { + s.set_read_timeout(value) + .map_err(|err| format!("failed to set read timeout: {err}"))?; + } + #[cfg(feature = "tls")] + MaybeTlsStream::Rustls(s) => { + s.get_mut() + .set_read_timeout(value) + .map_err(|err| format!("failed to set read timeout: {err}"))?; + } + _ => {} + }; + + Ok(()) } #[test]