Skip to content

Commit

Permalink
Bringing splitting back
Browse files Browse the repository at this point in the history
  • Loading branch information
dbcfd committed Oct 19, 2019
1 parent 46ac847 commit 1cbac51
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 20 deletions.
83 changes: 83 additions & 0 deletions examples/split-client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//! A simple example of hooking up stdin/stdout to a WebSocket stream.
//!
//! This example will connect to a server specified in the argument list and
//! then forward all data read on stdin to the server, printing out all data
//! received on stdout.
//!
//! Note that this is not currently optimized for performance, especially around
//! buffer management. Rather it's intended to show an example of working with a
//! client.
//!
//! You can use this example together with the `server` example.
use std::env;
use std::io::{self, Write};

use futures::{SinkExt, StreamExt};
use log::*;
use tungstenite::protocol::Message;

use tokio::io::AsyncReadExt;
use tokio_tungstenite::connect_async;

#[tokio::main]
async fn main() {
let _ = env_logger::try_init();

// Specify the server address to which the client will be connecting.
let connect_addr = env::args()
.nth(1)
.unwrap_or_else(|| panic!("this program requires at least one argument"));

let url = url::Url::parse(&connect_addr).unwrap();

// Right now Tokio doesn't support a handle to stdin running on the event
// loop, so we farm out that work to a separate thread. This thread will
// read data from stdin and then send it to the event loop over a standard
// futures channel.
let (stdin_tx, mut stdin_rx) = futures::channel::mpsc::unbounded();
tokio::spawn(read_stdin(stdin_tx));

// After the TCP connection has been established, we set up our client to
// start forwarding data.
//
// First we do a WebSocket handshake on a TCP stream, i.e. do the upgrade
// request.
//
// Half of the work we're going to do is to take all data we receive on
// stdin (`stdin_rx`) and send that along the WebSocket stream (`sink`).
// The second half is to take all the data we receive (`stream`) and then
// write that to stdout. Currently we just write to stdout in a synchronous
// fashion.
//
// Finally we set the client to terminate once either half of this work
// finishes. If we don't have any more data to read or we won't receive any
// more work from the remote then we can exit.
let mut stdout = io::stdout();
let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");
let (mut ws_tx, mut ws_rx) = ws_stream.split();
info!("WebSocket handshake has been successfully completed");

while let Some(msg) = stdin_rx.next().await {
ws_tx.send(msg).await.expect("Failed to send request");
if let Some(msg) = ws_rx.next().await {
let msg = msg.expect("Failed to get response");
stdout.write_all(&msg.into_data()).unwrap();
}
}
}

// Our helper method which will read data from stdin and send it along the
// sender provided.
async fn read_stdin(tx: futures::channel::mpsc::UnboundedSender<Message>) {
let mut stdin = tokio::io::stdin();
loop {
let mut buf = vec![0; 1024];
let n = match stdin.read(&mut buf).await {
Err(_) | Ok(0) => break,
Ok(n) => n,
};
buf.truncate(n);
tx.unbounded_send(Message::binary(buf)).unwrap();
}
}
20 changes: 12 additions & 8 deletions src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ use tokio_io::{AsyncRead, AsyncWrite};
use tungstenite::{Error as WsError, WebSocket};

pub(crate) trait HasContext {
fn set_context(&mut self, context: *mut ());
fn set_context(&mut self, context: (bool, *mut ()));
}
#[derive(Debug)]
pub struct AllowStd<S> {
pub(crate) inner: S,
pub(crate) context: *mut (),
pub(crate) context: (bool, *mut ()),
}

impl<S> HasContext for AllowStd<S> {
fn set_context(&mut self, context: *mut ()) {
fn set_context(&mut self, context: (bool, *mut ())) {
self.context = context;
}
}
Expand All @@ -26,7 +26,7 @@ pub(crate) struct Guard<'a, S>(pub(crate) &'a mut WebSocket<AllowStd<S>>);
impl<S> Drop for Guard<'_, S> {
fn drop(&mut self) {
trace!("{}:{} Guard.drop", file!(), line!());
(self.0).get_mut().context = std::ptr::null_mut();
(self.0).get_mut().context = (true, std::ptr::null_mut());
}
}

Expand All @@ -38,14 +38,18 @@ impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> R
fn with_context<F, R>(&mut self, f: F) -> Poll<std::io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());
unsafe {
assert!(!self.context.is_null());
let waker = &mut *(self.context as *mut _);
if !self.context.0 {
//was called by start_send without context
return Poll::Pending
}
assert!(!self.context.1.is_null());
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
trace!("Setting context when skipping handshake");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
context: (true, ctx as *mut _ as *mut ()),
};

Poll::Ready((inner.f)(stream))
Expand Down Expand Up @@ -137,14 +137,14 @@ where
trace!("Setting ctx when starting handshake");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
context: (true, ctx as *mut _ as *mut ()),
};

match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context(std::ptr::null_mut());
machine.get_mut().set_context((true, std::ptr::null_mut()));
Poll::Ready(Ok(StartedHandshake::Mid(mid)))
}
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Expand All @@ -165,14 +165,14 @@ where

let machine = s.get_mut();
trace!("Setting context in handshake");
machine.get_mut().set_context(cx as *mut _ as *mut ());
machine.get_mut().set_context((true, cx as *mut _ as *mut ()));

match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mut mid)) => {
let machine = mid.get_mut();
machine.get_mut().set_context(std::ptr::null_mut());
machine.get_mut().set_context((true, std::ptr::null_mut()));
*this.0 = Some(mid);
Poll::Pending
}
Expand Down
57 changes: 51 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod stream;
use std::io::{Read, Write};

use compat::{cvt, AllowStd};
use futures::Stream;
use futures::{Stream, Sink};
use log::*;
use pin_project::pin_project;
use std::future::Future;
Expand Down Expand Up @@ -214,14 +214,17 @@ impl<S> WebSocketStream<S> {
WebSocketStream { inner: ws }
}

fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
fn with_context<F, R>(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R
where
S: Unpin,
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
self.inner.get_mut().context = ctx as *mut _ as *mut ();
self.inner.get_mut().context = match ctx {
None => (false, std::ptr::null_mut()),
Some(cx) => (true, cx as *mut _ as *mut ()),
};
let mut g = compat::Guard(&mut self.inner);
f(&mut (g.0))
}
Expand Down Expand Up @@ -276,7 +279,7 @@ where

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
trace!("{}:{} Stream.poll_next", file!(), line!());
match futures::ready!(self.with_context(cx, |s| {
match futures::ready!(self.with_context(Some(cx), |s| {
trace!(
"{}:{} Stream.with_context poll_next -> read_message()",
file!(),
Expand All @@ -291,6 +294,48 @@ where
}
}

impl<T> Sink<Message> for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
AllowStd<T>: Read + Write,
{
type Error = WsError;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| cvt(s.write_pending()))
}

fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match (*self).with_context(None, |s| s.write_message(item)) {
Ok(()) => Ok(()),
Err(::tungstenite::Error::Io(ref err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued
// isn't an error.
Ok(())
}
Err(e) => {
debug!("websocket start_send error: {}", e);
Err(e)
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| cvt(s.write_pending()))
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match (*self).with_context(Some(cx), |s| s.close(None)) {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => {
debug!("websocket close error: {}", err);
Poll::Ready(Err(err))
}
}
}
}

#[pin_project]
struct SendFuture<'a, T> {
stream: &'a mut WebSocketStream<T>,
Expand All @@ -307,7 +352,7 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(this.stream.with_context(cx, |s| s.write_message(message)))
Poll::Ready(this.stream.with_context(Some(cx), |s| s.write_message(message)))
}
}

Expand All @@ -327,7 +372,7 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let message = this.message.take().expect("Cannot poll twice");
Poll::Ready(this.stream.with_context(cx, |s| s.close(message)))
Poll::Ready(this.stream.with_context(Some(cx), |s| s.close(message)))
}
}

Expand Down
58 changes: 57 additions & 1 deletion tests/communication.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::StreamExt;
use futures::{SinkExt, StreamExt};
use log::*;
use std::net::ToSocketAddrs;
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -77,3 +77,59 @@ async fn communication() {
let messages = msg_rx.await.expect("Failed to receive messages");
assert_eq!(messages.len(), 10);
}

#[tokio::test]
async fn split_communication() {
let _ = env_logger::try_init();

let (con_tx, con_rx) = futures::channel::oneshot::channel();
let (msg_tx, msg_rx) = futures::channel::oneshot::channel();

let f = async move {
let address = "0.0.0.0:12346"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let listener = TcpListener::bind(&address).await.unwrap();
let mut connections = listener.incoming();
info!("Server ready");
con_tx.send(()).unwrap();
info!("Waiting on next connection");
let connection = connections.next().await.expect("No connections to accept");
let connection = connection.expect("Failed to accept connection");
let stream = accept_async(connection).await;
let stream = stream.expect("Failed to handshake with connection");
run_connection(stream, msg_tx).await;
};

tokio::spawn(f);

info!("Waiting for server to be ready");

con_rx.await.expect("Server not ready");
let address = "0.0.0.0:12346"
.to_socket_addrs()
.expect("Not a valid address")
.next()
.expect("No address resolved");
let tcp = TcpStream::connect(&address)
.await
.expect("Failed to connect");
let url = url::Url::parse("ws://localhost:12345/").unwrap();
let (stream, _) = client_async(url, tcp)
.await
.expect("Client failed to connect");
let (mut tx, _rx) = stream.split();

for i in 1..10 {
info!("Sending message");
tx.send(Message::Text(format!("{}", i))).await.expect("Failed to send message");
}

tx.close().await.expect("Failed to close");

info!("Waiting for response messages");
let messages = msg_rx.await.expect("Failed to receive messages");
assert_eq!(messages.len(), 10);
}

0 comments on commit 1cbac51

Please sign in to comment.