From 1cbac51d14174e62e159b162d20692f128fec916 Mon Sep 17 00:00:00 2001 From: Danny Browning Date: Sat, 19 Oct 2019 11:15:12 -0600 Subject: [PATCH] Bringing splitting back --- examples/split-client.rs | 83 ++++++++++++++++++++++++++++++++++++++++ src/compat.rs | 20 ++++++---- src/handshake.rs | 10 ++--- src/lib.rs | 57 ++++++++++++++++++++++++--- tests/communication.rs | 58 +++++++++++++++++++++++++++- 5 files changed, 208 insertions(+), 20 deletions(-) create mode 100644 examples/split-client.rs diff --git a/examples/split-client.rs b/examples/split-client.rs new file mode 100644 index 00000000..a8f68009 --- /dev/null +++ b/examples/split-client.rs @@ -0,0 +1,83 @@ +//! A simple example of hooking up stdin/stdout to a WebSocket stream. +//! +//! This example will connect to a server specified in the argument list and +//! then forward all data read on stdin to the server, printing out all data +//! received on stdout. +//! +//! Note that this is not currently optimized for performance, especially around +//! buffer management. Rather it's intended to show an example of working with a +//! client. +//! +//! You can use this example together with the `server` example. + +use std::env; +use std::io::{self, Write}; + +use futures::{SinkExt, StreamExt}; +use log::*; +use tungstenite::protocol::Message; + +use tokio::io::AsyncReadExt; +use tokio_tungstenite::connect_async; + +#[tokio::main] +async fn main() { + let _ = env_logger::try_init(); + + // Specify the server address to which the client will be connecting. + let connect_addr = env::args() + .nth(1) + .unwrap_or_else(|| panic!("this program requires at least one argument")); + + let url = url::Url::parse(&connect_addr).unwrap(); + + // Right now Tokio doesn't support a handle to stdin running on the event + // loop, so we farm out that work to a separate thread. This thread will + // read data from stdin and then send it to the event loop over a standard + // futures channel. + let (stdin_tx, mut stdin_rx) = futures::channel::mpsc::unbounded(); + tokio::spawn(read_stdin(stdin_tx)); + + // After the TCP connection has been established, we set up our client to + // start forwarding data. + // + // First we do a WebSocket handshake on a TCP stream, i.e. do the upgrade + // request. + // + // Half of the work we're going to do is to take all data we receive on + // stdin (`stdin_rx`) and send that along the WebSocket stream (`sink`). + // The second half is to take all the data we receive (`stream`) and then + // write that to stdout. Currently we just write to stdout in a synchronous + // fashion. + // + // Finally we set the client to terminate once either half of this work + // finishes. If we don't have any more data to read or we won't receive any + // more work from the remote then we can exit. + let mut stdout = io::stdout(); + let (ws_stream, _) = connect_async(url).await.expect("Failed to connect"); + let (mut ws_tx, mut ws_rx) = ws_stream.split(); + info!("WebSocket handshake has been successfully completed"); + + while let Some(msg) = stdin_rx.next().await { + ws_tx.send(msg).await.expect("Failed to send request"); + if let Some(msg) = ws_rx.next().await { + let msg = msg.expect("Failed to get response"); + stdout.write_all(&msg.into_data()).unwrap(); + } + } +} + +// Our helper method which will read data from stdin and send it along the +// sender provided. +async fn read_stdin(tx: futures::channel::mpsc::UnboundedSender) { + let mut stdin = tokio::io::stdin(); + loop { + let mut buf = vec![0; 1024]; + let n = match stdin.read(&mut buf).await { + Err(_) | Ok(0) => break, + Ok(n) => n, + }; + buf.truncate(n); + tx.unbounded_send(Message::binary(buf)).unwrap(); + } +} diff --git a/src/compat.rs b/src/compat.rs index 218a338e..1ab90d90 100644 --- a/src/compat.rs +++ b/src/compat.rs @@ -7,16 +7,16 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tungstenite::{Error as WsError, WebSocket}; pub(crate) trait HasContext { - fn set_context(&mut self, context: *mut ()); + fn set_context(&mut self, context: (bool, *mut ())); } #[derive(Debug)] pub struct AllowStd { pub(crate) inner: S, - pub(crate) context: *mut (), + pub(crate) context: (bool, *mut ()), } impl HasContext for AllowStd { - fn set_context(&mut self, context: *mut ()) { + fn set_context(&mut self, context: (bool, *mut ())) { self.context = context; } } @@ -26,7 +26,7 @@ pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket>); impl Drop for Guard<'_, S> { fn drop(&mut self) { trace!("{}:{} Guard.drop", file!(), line!()); - (self.0).get_mut().context = std::ptr::null_mut(); + (self.0).get_mut().context = (true, std::ptr::null_mut()); } } @@ -38,14 +38,18 @@ impl AllowStd where S: Unpin, { - fn with_context(&mut self, f: F) -> R + fn with_context(&mut self, f: F) -> Poll> where - F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, + F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll>, { trace!("{}:{} AllowStd.with_context", file!(), line!()); unsafe { - assert!(!self.context.is_null()); - let waker = &mut *(self.context as *mut _); + if !self.context.0 { + //was called by start_send without context + return Poll::Pending + } + assert!(!self.context.1.is_null()); + let waker = &mut *(self.context.1 as *mut _); f(waker, Pin::new(&mut self.inner)) } } diff --git a/src/handshake.rs b/src/handshake.rs index 9c4738cd..33dac86b 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -47,7 +47,7 @@ where trace!("Setting context when skipping handshake"); let stream = AllowStd { inner: inner.stream, - context: ctx as *mut _ as *mut (), + context: (true, ctx as *mut _ as *mut ()), }; Poll::Ready((inner.f)(stream)) @@ -137,14 +137,14 @@ where trace!("Setting ctx when starting handshake"); let stream = AllowStd { inner: inner.stream, - context: ctx as *mut _ as *mut (), + context: (true, ctx as *mut _ as *mut ()), }; match (inner.f)(stream) { Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))), Err(Error::Interrupted(mut mid)) => { let machine = mid.get_mut(); - machine.get_mut().set_context(std::ptr::null_mut()); + machine.get_mut().set_context((true, std::ptr::null_mut())); Poll::Ready(Ok(StartedHandshake::Mid(mid))) } Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), @@ -165,14 +165,14 @@ where let machine = s.get_mut(); trace!("Setting context in handshake"); - machine.get_mut().set_context(cx as *mut _ as *mut ()); + machine.get_mut().set_context((true, cx as *mut _ as *mut ())); match s.handshake() { Ok(stream) => Poll::Ready(Ok(stream)), Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))), Err(Error::Interrupted(mut mid)) => { let machine = mid.get_mut(); - machine.get_mut().set_context(std::ptr::null_mut()); + machine.get_mut().set_context((true, std::ptr::null_mut())); *this.0 = Some(mid); Poll::Pending } diff --git a/src/lib.rs b/src/lib.rs index f7cb1935..9955c612 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ pub mod stream; use std::io::{Read, Write}; use compat::{cvt, AllowStd}; -use futures::Stream; +use futures::{Stream, Sink}; use log::*; use pin_project::pin_project; use std::future::Future; @@ -214,14 +214,17 @@ impl WebSocketStream { WebSocketStream { inner: ws } } - fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R + fn with_context(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R where S: Unpin, F: FnOnce(&mut WebSocket>) -> R, AllowStd: Read + Write, { trace!("{}:{} WebSocketStream.with_context", file!(), line!()); - self.inner.get_mut().context = ctx as *mut _ as *mut (); + self.inner.get_mut().context = match ctx { + None => (false, std::ptr::null_mut()), + Some(cx) => (true, cx as *mut _ as *mut ()), + }; let mut g = compat::Guard(&mut self.inner); f(&mut (g.0)) } @@ -276,7 +279,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{}:{} Stream.poll_next", file!(), line!()); - match futures::ready!(self.with_context(cx, |s| { + match futures::ready!(self.with_context(Some(cx), |s| { trace!( "{}:{} Stream.with_context poll_next -> read_message()", file!(), @@ -291,6 +294,48 @@ where } } +impl Sink for WebSocketStream + where + T: AsyncRead + AsyncWrite + Unpin, + AllowStd: Read + Write, +{ + type Error = WsError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + (*self).with_context(Some(cx), |s| cvt(s.write_pending())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + match (*self).with_context(None, |s| s.write_message(item)) { + Ok(()) => Ok(()), + Err(::tungstenite::Error::Io(ref err)) if err.kind() == std::io::ErrorKind::WouldBlock => { + // the message was accepted and queued + // isn't an error. + Ok(()) + } + Err(e) => { + debug!("websocket start_send error: {}", e); + Err(e) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + (*self).with_context(Some(cx), |s| cvt(s.write_pending())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match (*self).with_context(Some(cx), |s| s.close(None)) { + Ok(()) => Poll::Ready(Ok(())), + Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), + Err(err) => { + debug!("websocket close error: {}", err); + Poll::Ready(Err(err)) + } + } + } +} + #[pin_project] struct SendFuture<'a, T> { stream: &'a mut WebSocketStream, @@ -307,7 +352,7 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let message = this.message.take().expect("Cannot poll twice"); - Poll::Ready(this.stream.with_context(cx, |s| s.write_message(message))) + Poll::Ready(this.stream.with_context(Some(cx), |s| s.write_message(message))) } } @@ -327,7 +372,7 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let message = this.message.take().expect("Cannot poll twice"); - Poll::Ready(this.stream.with_context(cx, |s| s.close(message))) + Poll::Ready(this.stream.with_context(Some(cx), |s| s.close(message))) } } diff --git a/tests/communication.rs b/tests/communication.rs index e9c3e5b8..61d00122 100644 --- a/tests/communication.rs +++ b/tests/communication.rs @@ -1,4 +1,4 @@ -use futures::StreamExt; +use futures::{SinkExt, StreamExt}; use log::*; use std::net::ToSocketAddrs; use tokio::io::{AsyncRead, AsyncWrite}; @@ -77,3 +77,59 @@ async fn communication() { let messages = msg_rx.await.expect("Failed to receive messages"); assert_eq!(messages.len(), 10); } + +#[tokio::test] +async fn split_communication() { + let _ = env_logger::try_init(); + + let (con_tx, con_rx) = futures::channel::oneshot::channel(); + let (msg_tx, msg_rx) = futures::channel::oneshot::channel(); + + let f = async move { + let address = "0.0.0.0:12346" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("No address resolved"); + let listener = TcpListener::bind(&address).await.unwrap(); + let mut connections = listener.incoming(); + info!("Server ready"); + con_tx.send(()).unwrap(); + info!("Waiting on next connection"); + let connection = connections.next().await.expect("No connections to accept"); + let connection = connection.expect("Failed to accept connection"); + let stream = accept_async(connection).await; + let stream = stream.expect("Failed to handshake with connection"); + run_connection(stream, msg_tx).await; + }; + + tokio::spawn(f); + + info!("Waiting for server to be ready"); + + con_rx.await.expect("Server not ready"); + let address = "0.0.0.0:12346" + .to_socket_addrs() + .expect("Not a valid address") + .next() + .expect("No address resolved"); + let tcp = TcpStream::connect(&address) + .await + .expect("Failed to connect"); + let url = url::Url::parse("ws://localhost:12345/").unwrap(); + let (stream, _) = client_async(url, tcp) + .await + .expect("Client failed to connect"); + let (mut tx, _rx) = stream.split(); + + for i in 1..10 { + info!("Sending message"); + tx.send(Message::Text(format!("{}", i))).await.expect("Failed to send message"); + } + + tx.close().await.expect("Failed to close"); + + info!("Waiting for response messages"); + let messages = msg_rx.await.expect("Failed to receive messages"); + assert_eq!(messages.len(), 10); +}