Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix blocking receiver sleeping after every read #48

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions ewebsock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,12 @@ pub struct Options {
/// Currently only supported on native.
pub subprotocols: Vec<String>,

/// Delay blocking in ms - default 10ms
pub delay_blocking: std::time::Duration,
/// Socket read timeout.
///
/// Reads will block forever if this is set to `None` or `Some(Duration::ZERO)`.
///
/// Defaults to 10ms.
pub read_timeout: Option<std::time::Duration>,
}

impl Default for Options {
Expand All @@ -157,7 +161,9 @@ impl Default for Options {
max_incoming_frame_size: 64 * 1024 * 1024,
additional_headers: vec![],
subprotocols: vec![],
delay_blocking: std::time::Duration::from_millis(10), // default value 10ms,
// let the OS schedule something else, otherwise busy-loop
// TODO: use polling on native instead
read_timeout: Some(std::time::Duration::from_millis(10)),
}
}
}
Expand Down
160 changes: 81 additions & 79 deletions ewebsock/src/native_tungstenite.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
//! Native implementation of the WebSocket client using the `tungstenite` crate.

use std::net::TcpStream;
use std::{
ops::ControlFlow,
sync::mpsc::{Receiver, TryRecvError},
};

use tungstenite::stream::MaybeTlsStream;
use tungstenite::WebSocket;

use crate::tungstenite_common::into_requester;
use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

Expand Down Expand Up @@ -76,6 +80,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
let config = tungstenite::protocol::WebSocketConfig::from(options.clone());
let max_redirects = 3; // tungstenite default

let read_timeout = options.read_timeout;
let (mut socket, response) = match tungstenite::client::connect_with_config(
into_requester(uri, options),
Some(config),
Expand All @@ -87,6 +92,8 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
}
};

set_read_timeout(&mut socket, read_timeout)?;

log::debug!("WebSocket HTTP response code: {}", response.status());
log::trace!(
"WebSocket response contains the following headers: {:?}",
Expand All @@ -102,31 +109,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
}

loop {
let control = match socket.read() {
Ok(incoming_msg) => match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("WebSocket close received: {close:?}");
return Ok(());
}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
},
Err(err) => {
return Err(format!("read: {err}"));
}
};
let control = read_from_socket(&mut socket, on_event)?;

if control.is_break() {
log::trace!("Closing connection due to Break");
Expand All @@ -135,7 +118,7 @@ pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler
.map_err(|err| format!("Failed to close connection: {err}"));
}

std::thread::sleep(std::time::Duration::from_millis(10));
std::thread::yield_now();
}
}

Expand Down Expand Up @@ -172,12 +155,13 @@ pub fn ws_connect_blocking(
on_event: &EventHandler,
rx: &Receiver<WsMessage>,
) -> Result<()> {
let delay = options.delay_blocking;
let config = tungstenite::protocol::WebSocketConfig::from(options.clone());
let max_redirects = 3; // tungstenite default
let uri: tungstenite::http::Uri = url
.parse()
.map_err(|err| format!("Failed to parse URL {url:?}: {err}"))?;

let read_timeout = options.read_timeout;
let (mut socket, response) = match tungstenite::client::connect_with_config(
into_requester(uri, options),
Some(config),
Expand All @@ -189,6 +173,8 @@ pub fn ws_connect_blocking(
}
};

set_read_timeout(&mut socket, read_timeout)?;

log::debug!("WebSocket HTTP response code: {}", response.status());
log::trace!(
"WebSocket response contains the following headers: {:?}",
Expand All @@ -203,26 +189,9 @@ pub fn ws_connect_blocking(
.map_err(|err| format!("Failed to close connection: {err}"));
}

match socket.get_mut() {
tungstenite::stream::MaybeTlsStream::Plain(stream) => stream.set_nonblocking(true),

// tungstenite::stream::MaybeTlsStream::NativeTls(stream) => {
// stream.get_mut().set_nonblocking(true)
// }
#[cfg(feature = "tls")]
tungstenite::stream::MaybeTlsStream::Rustls(stream) => {
stream.get_mut().set_nonblocking(true)
}
_ => return Err(format!("Unknown tungstenite stream {:?}", socket.get_mut())),
}
.map_err(|err| format!("Failed to make WebSocket non-blocking: {err}"))?;

loop {
let mut did_work = false;

match rx.try_recv() {
Ok(outgoing_message) => {
did_work = true;
let outgoing_message = match outgoing_message {
WsMessage::Text(text) => tungstenite::protocol::Message::Text(text),
WsMessage::Binary(data) => tungstenite::protocol::Message::Binary(data),
Expand All @@ -245,39 +214,7 @@ pub fn ws_connect_blocking(
Err(TryRecvError::Empty) => {}
};

let control = match socket.read() {
Ok(incoming_msg) => {
did_work = true;
match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("Close received: {close:?}");
return Ok(());
}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
}
}
Err(tungstenite::Error::Io(io_err))
if io_err.kind() == std::io::ErrorKind::WouldBlock =>
{
ControlFlow::Continue(()) // Ignore
}
Err(err) => {
return Err(format!("read: {err}"));
}
};
let control = read_from_socket(&mut socket, on_event)?;

if control.is_break() {
log::trace!("Closing connection due to Break");
Expand All @@ -286,10 +223,75 @@ pub fn ws_connect_blocking(
.map_err(|err| format!("Failed to close connection: {err}"));
}

if !did_work {
std::thread::sleep(delay);
std::thread::yield_now();
}
}

fn read_from_socket(
socket: &mut WebSocket<MaybeTlsStream<TcpStream>>,
on_event: &EventHandler,
) -> Result<ControlFlow<()>> {
let control = match socket.read() {
Ok(incoming_msg) => match incoming_msg {
tungstenite::protocol::Message::Text(text) => {
on_event(WsEvent::Message(WsMessage::Text(text)))
}
tungstenite::protocol::Message::Binary(data) => {
on_event(WsEvent::Message(WsMessage::Binary(data)))
}
tungstenite::protocol::Message::Ping(data) => {
on_event(WsEvent::Message(WsMessage::Ping(data)))
}
tungstenite::protocol::Message::Pong(data) => {
on_event(WsEvent::Message(WsMessage::Pong(data)))
}
tungstenite::protocol::Message::Close(close) => {
on_event(WsEvent::Closed);
log::debug!("WebSocket close received: {close:?}");
ControlFlow::Break(())
}
tungstenite::protocol::Message::Frame(_) => ControlFlow::Continue(()),
},
// If we get `WouldBlock`, then the read timed out.
// Windows may emit `TimedOut` instead.
Err(tungstenite::Error::Io(io_err))
if io_err.kind() == std::io::ErrorKind::WouldBlock
|| io_err.kind() == std::io::ErrorKind::TimedOut =>
{
ControlFlow::Continue(()) // Ignore
}
Err(err) => {
return Err(format!("read: {err}"));
}
};

Ok(control)
}

fn set_read_timeout(
s: &mut WebSocket<MaybeTlsStream<TcpStream>>,
value: Option<std::time::Duration>,
) -> Result<()> {
// zero timeout is the same as no timeout
if value.is_none() || value.is_some_and(|value| value.is_zero()) {
return Ok(());
}

match s.get_mut() {
MaybeTlsStream::Plain(s) => {
s.set_read_timeout(value)
.map_err(|err| format!("failed to set read timeout: {err}"))?;
}
#[cfg(feature = "tls")]
MaybeTlsStream::Rustls(s) => {
s.get_mut()
.set_read_timeout(value)
.map_err(|err| format!("failed to set read timeout: {err}"))?;
}
_ => {}
};

Ok(())
}

#[test]
Expand Down
2 changes: 2 additions & 0 deletions example_app/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Example application.

mod app;
pub use app::ExampleApp;

Expand Down
Loading