From 9eb9a8f7698bfea303a7bf5c6771db3740cb6b2a Mon Sep 17 00:00:00 2001 From: Youyuan Wu Date: Sat, 18 May 2024 22:00:51 -0700 Subject: [PATCH] switch some apis to use polling --- Cargo.lock | 87 ++++++++++ crates/libs/msquic/Cargo.toml | 1 + crates/libs/msquic/src/buffer.rs | 59 ++++--- crates/libs/msquic/src/lib.rs | 2 + crates/libs/msquic/src/msh3/mod.rs | 210 ++++++++++++++++++++++++ crates/libs/msquic/src/stream.rs | 197 ++++++++++++++++------- crates/libs/msquic/src/sync.rs | 248 ++++++++++++++++++++++++++++- 7 files changed, 721 insertions(+), 83 deletions(-) create mode 100644 crates/libs/msquic/src/msh3/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 3365a74..32351e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -128,12 +128,72 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-core", + "futures-io", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "h3" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c8886b9e6e93e7ed93d9433f3779e8d07e3ff96bc67b977d14c7b20c849411" +dependencies = [ + "bytes", + "fastrand", + "futures-util", + "http", + "pin-project-lite", + "tokio", + "tracing", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -146,6 +206,17 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "itoa" version = "1.0.11" @@ -212,6 +283,7 @@ version = "0.1.0" dependencies = [ "bytes", "c2", + "h3", "hex", "tokio", "tracing", @@ -282,6 +354,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "proc-macro2" version = "1.0.80" @@ -404,6 +482,15 @@ dependencies = [ "windows-core", ] +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.13.2" diff --git a/crates/libs/msquic/Cargo.toml b/crates/libs/msquic/Cargo.toml index f08d21d..ef8408d 100644 --- a/crates/libs/msquic/Cargo.toml +++ b/crates/libs/msquic/Cargo.toml @@ -12,6 +12,7 @@ edition = "2021" tokio = {version = "1", features = ["sync"]} tracing = { version = "0.1.40", features = ["log"] } bytes = "*" +h3 = "*" [dev-dependencies] # env_logger = "0.10.1" diff --git a/crates/libs/msquic/src/buffer.rs b/crates/libs/msquic/src/buffer.rs index 3fcac4d..2b25d9f 100644 --- a/crates/libs/msquic/src/buffer.rs +++ b/crates/libs/msquic/src/buffer.rs @@ -144,22 +144,21 @@ impl From<&SBuffer> for Buffer { } } -pub struct QBufWrap { - _inner: Box, // mem owner +pub struct QBufWrap { + _inner: Box, // mem owner v: Vec, } -unsafe impl Send for QBufWrap {} +unsafe impl Send for QBufWrap {} -impl QBufWrap { - pub fn new(buf: B) -> Self { +impl QBufWrap { + pub fn new(mut buf: Box) -> Self { // make on heap so that no ptr move. - let mut inner = Box::new(buf); - let v = Self::convert_buf(&mut inner); - Self { _inner: inner, v } + let v = Self::convert_buf(&mut buf); + Self { _inner: buf, v } } - fn convert_buf(b: &mut impl Buf) -> Vec { + fn convert_buf(b: &mut Box) -> Vec { let mut v = Vec::new(); // change buf to vecs while b.has_remaining() { @@ -188,6 +187,7 @@ impl QBytesMut { let mut res = BytesMut::new(); b.iter().for_each(|i| { let s = unsafe { slice::from_raw_parts(i.buffer, i.length.try_into().unwrap()) }; + res.reserve(s.len()); res.put_slice(s); }); Self(res) @@ -195,25 +195,28 @@ impl QBytesMut { } pub fn debug_buf_to_string(mut b: impl Buf) -> String { - let cp = b.copy_to_bytes(b.remaining()); - String::from_utf8_lossy(&cp).into_owned() + let mut dst = vec![0; b.remaining()]; + b.copy_to_slice(&mut dst[..]); + // let cp = b.copy_to_bytes(b.remaining()); + String::from_utf8_lossy(&dst).into_owned() +} + +pub fn debug_raw_buf_to_string(b: Buffer) -> String { + let s = String::from_utf8_lossy(unsafe { + slice::from_raw_parts(b.buffer, b.length.try_into().unwrap()) + }); + s.into_owned() } #[cfg(test)] mod test { - use core::slice; use bytes::{BufMut, Bytes, BytesMut}; use c2::Buffer; - use super::{QBufWrap, QBuffRef, QBufferVec, QVecBuffer}; + use crate::buffer::debug_raw_buf_to_string; - fn buf_to_string(b: Buffer) -> String { - let s = String::from_utf8_lossy(unsafe { - slice::from_raw_parts(b.buffer, b.length.try_into().unwrap()) - }); - s.into_owned() - } + use super::{debug_buf_to_string, QBufWrap, QBuffRef, QBufferVec, QBytesMut, QVecBuffer}; #[test] fn test_vec_buffer() { @@ -239,11 +242,11 @@ mod test { #[test] fn test_buf() { let b = Bytes::from("mydata"); - let wrap = QBufWrap::new(b); + let wrap = QBufWrap::new(Box::new(b)); let v = wrap.as_buffs(); assert_eq!(v.len(), 1); let b1 = v[0]; - let s = buf_to_string(b1); + let s = debug_raw_buf_to_string(b1); assert_eq!(s, "mydata"); } @@ -252,11 +255,21 @@ mod test { let mut b = BytesMut::with_capacity(5); b.put(&b"hello"[..]); b.put(&b"world"[..]); // this will grow - let wrap = QBufWrap::new(b); + let wrap = QBufWrap::new(Box::new(b)); let v = wrap.as_buffs(); assert_eq!(v.len(), 1); let b1 = v[0]; - let s = buf_to_string(b1); + let s = debug_raw_buf_to_string(b1); + assert_eq!(s, "helloworld"); + } + + #[test] + fn test_buf2str() { + let args: [QVecBuffer; 2] = [QVecBuffer::from("hello"), QVecBuffer::from("world")]; + let buffer_vec = QBufferVec::from(args.as_slice()); + let bm = QBytesMut::from_buffs(buffer_vec.as_buffers()); + std::mem::drop(args); + let s = debug_buf_to_string(bm.0); assert_eq!(s, "helloworld"); } } diff --git a/crates/libs/msquic/src/lib.rs b/crates/libs/msquic/src/lib.rs index 1c3feaa..dedf422 100644 --- a/crates/libs/msquic/src/lib.rs +++ b/crates/libs/msquic/src/lib.rs @@ -15,6 +15,8 @@ pub mod stream; pub mod sync; mod utils; +//pub mod msh3; + // Some useful defs pub const QUIC_STATUS_PENDING: u32 = 0x703e5; pub const QUIC_STATUS_SUCCESS: u32 = 0; diff --git a/crates/libs/msquic/src/msh3/mod.rs b/crates/libs/msquic/src/msh3/mod.rs new file mode 100644 index 0000000..674a39f --- /dev/null +++ b/crates/libs/msquic/src/msh3/mod.rs @@ -0,0 +1,210 @@ +// h3 wrappings for msquic + +use std::{ + fmt::Display, + future::Future, + pin::{self, pin}, + sync::Arc, + task::Poll, +}; + +use bytes::{Buf, BytesMut}; +use c2::SEND_FLAG_NONE; +use h3::quic::{BidiStream, Connection, OpenStreams, RecvStream, SendStream, StreamId}; + +use crate::{conn::QConnection, stream::QStream}; + +#[derive(Debug)] +pub struct H3Error { + status: std::io::Error, + error_code: Option, +} + +impl H3Error{ + pub fn new(status: std::io::Error, ec: Option) -> Self{ + Self { status, error_code: ec } + } +} + +impl h3::quic::Error for H3Error { + fn is_timeout(&self) -> bool { + self.status.kind() == std::io::ErrorKind::TimedOut + } + + fn err_code(&self) -> Option { + self.error_code + } +} + +impl std::error::Error for H3Error {} + +impl Display for H3Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + todo!() + } +} + +pub struct H3Conn { + inner: QConnection, +} + +impl OpenStreams for H3Conn { + type BidiStream = H3Stream; + + type SendStream = H3Stream; + + type RecvStream = H3Stream; + + type Error = H3Error; + + fn poll_open_bidi( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } + + fn poll_open_send( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } + + fn close(&mut self, code: h3::error::Code, reason: &[u8]) { + todo!() + } +} + +impl Connection for H3Conn { + type BidiStream = H3Stream; + + type SendStream = H3Stream; + + type RecvStream = H3Stream; + + type OpenStreams = H3Conn; + + type Error = H3Error; + + fn poll_accept_recv( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>> { + todo!() + } + + fn poll_accept_bidi( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>> { + todo!() + } + + fn poll_open_bidi( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } + + fn poll_open_send( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } + + fn opener(&self) -> Self::OpenStreams { + todo!() + } + + fn close(&mut self, code: h3::error::Code, reason: &[u8]) { + todo!() + } +} + +pub struct H3Stream { + inner: QStream, + id: h3::quic::StreamId, + //read: +} + +impl H3Stream { + fn new(s: QStream, id: StreamId) -> Self { + Self { inner: s, id } + } +} + +impl SendStream for H3Stream { + type Error = H3Error; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // always ready to send? + // convert this to open or start? + // if send is in progress? + Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + let b: h3::quic::WriteBuf = data.into(); + self.inner.send_only(b, SEND_FLAG_NONE); + Ok(()) + } + + fn poll_finish( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_send(cx).map_err(|e|{H3Error::new(e, None)}) + } + + fn reset(&mut self, _reset_code: u64) { + panic!("reset not supported") + } + + fn send_id(&self) -> h3::quic::StreamId { + self.id + } +} + +impl RecvStream for H3Stream { + type Buf = BytesMut; + + type Error = H3Error; + + fn poll_data( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>> { + // let fu = self.inner.receive(); + // let innner = as Future>::poll(Pin::new(&mut self.rx), _cx); + //Pin::new(&mut fu).poll(cx); + // let mut pinned_fut = pin!(fu); + // pinned_fut.poll(cx); + todo!() + } + + fn stop_sending(&mut self, error_code: u64) { + self.inner.stop_sending(error_code); + } + + fn recv_id(&self) -> h3::quic::StreamId { + self.id + } +} + +impl BidiStream for H3Stream { + type SendStream = H3Stream; + + type RecvStream = H3Stream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + let cp = self.inner.clone(); + let id = self.id.clone(); + (self, H3Stream::new(cp, id)) + } +} diff --git a/crates/libs/msquic/src/stream.rs b/crates/libs/msquic/src/stream.rs index fdfd342..6a3f7d0 100644 --- a/crates/libs/msquic/src/stream.rs +++ b/crates/libs/msquic/src/stream.rs @@ -1,15 +1,17 @@ use std::{ ffi::c_void, + future::poll_fn, io::{Error, ErrorKind}, slice, - sync::Mutex, + sync::{Arc, Mutex, MutexGuard}, + task::Poll, }; use crate::{ - buffer::{QBufWrap, QBytesMut}, + buffer::{debug_buf_to_string, debug_raw_buf_to_string, QBufWrap, QBytesMut}, conn::QConnection, info, - sync::{QQueue, QResetChannel, QSignal}, + sync::{QSignal, QWakableQueue, QWakableSig}, utils::SBox, QApi, }; @@ -19,14 +21,14 @@ use c2::{ STREAM_EVENT_PEER_RECEIVE_ABORTED, STREAM_EVENT_PEER_SEND_ABORTED, STREAM_EVENT_PEER_SEND_SHUTDOWN, STREAM_EVENT_RECEIVE, STREAM_EVENT_SEND_COMPLETE, STREAM_EVENT_SEND_SHUTDOWN_COMPLETE, STREAM_EVENT_SHUTDOWN_COMPLETE, - STREAM_EVENT_START_COMPLETE, STREAM_SHUTDOWN_FLAG_NONE, + STREAM_EVENT_START_COMPLETE, STREAM_SHUTDOWN_FLAG_GRACEFUL, STREAM_SHUTDOWN_FLAG_NONE, }; -// #[derive(Debug)] +#[derive(Clone)] pub struct QStream { _api: QApi, - inner: SBox, - ctx: Box>, + inner: Arc>, // arc needed for copy + ctx: Arc>, } #[derive(Debug, Clone)] @@ -41,30 +43,30 @@ enum StartPayload { } struct QStreamCtx { - start_sig: QResetChannel, - receive_ch: QQueue, - send_ch: QResetChannel, + start_sig: QWakableQueue, + receive_ch: QWakableQueue, + send_sig: QWakableSig, send_shtdwn_sig: QSignal, drain_sig: QSignal, is_drained: bool, + pending_buf: Option, // because msquic copies buffers in background we need to hold the buffer temporarily } impl QStreamCtx { fn new() -> Self { Self { - start_sig: QResetChannel::new(), - receive_ch: QQueue::new(), - send_ch: QResetChannel::new(), + start_sig: QWakableQueue::default(), + receive_ch: QWakableQueue::default(), + send_sig: QWakableSig::default(), send_shtdwn_sig: QSignal::new(), drain_sig: QSignal::new(), is_drained: false, + pending_buf: None, } } fn on_start_complete(&mut self) { - if self.start_sig.can_set() { - self.start_sig.set(StartPayload::Success); - } + self.start_sig.insert(StartPayload::Success); } fn on_send_complete(&mut self, cancelled: bool) { let payload = if cancelled { @@ -72,23 +74,32 @@ impl QStreamCtx { } else { SentPayload::Success }; - if self.send_ch.can_set() { - self.send_ch.set(payload); - } + let prev = self.pending_buf.take(); // release buffer + assert!(prev.is_some()); + self.send_sig.set(payload); } fn on_receive(&mut self, buffs: &[Buffer]) { // send to frontend let v = QBytesMut::from_buffs(buffs); + let s = debug_buf_to_string(v.0.clone()); + let original = debug_raw_buf_to_string(buffs[0]); + info!( + "debug: receive bytes: {} len:{}, original {}, len: {}", + s, + s.len(), + original, + original.len() + ); self.receive_ch.insert(v); } fn on_peer_send_shutdown(&mut self) { // peer can shutdown their direction. But we should receive what is pending. // Peer will no longer send new stuff, so the receive can be dropped. // if frontend is waiting stop it. - self.receive_ch.close(0); + self.receive_ch.close(); } fn on_peer_send_abort(&mut self, _ec: u64) { - self.receive_ch.close(0); + self.receive_ch.close(); } fn on_send_shutdown_complete(&mut self) { if self.send_shtdwn_sig.can_set() { @@ -97,7 +108,7 @@ impl QStreamCtx { } fn on_shutdown_complete(&mut self) { // close all channels - self.receive_ch.close(0); + self.receive_ch.close(); // drain signal self.is_drained = true; if self.drain_sig.can_set() { @@ -131,11 +142,16 @@ extern "C" fn qstream_handler_callback( ctx.on_send_complete(raw.canceled); } STREAM_EVENT_RECEIVE => { - info!("[{:?}] QUIC_STREAM_EVENT_RECEIVE", stream); let raw = unsafe { event.payload.receive }; let count = raw.buffer_count; let curr = raw.buffer; let buffs = unsafe { slice::from_raw_parts(curr, count.try_into().unwrap()) }; + info!( + "[{:?}] QUIC_STREAM_EVENT_RECEIVE: buffer count {}, len {}", + stream, + buffs.len(), + buffs[0].length + ); ctx.on_receive(buffs); } STREAM_EVENT_PEER_SEND_SHUTDOWN => { @@ -172,7 +188,7 @@ extern "C" fn qstream_handler_callback( impl QStream { pub fn attach(api: QApi, h: Handle) -> Self { let s = Stream::from_parts(h, &api.inner.inner); - let ctx = Box::new(Mutex::new(QStreamCtx::new())); + let ctx = Arc::new(Mutex::new(QStreamCtx::new())); s.set_callback_handler( qstream_handler_callback, &*ctx as *const Mutex as *const c_void, @@ -180,7 +196,7 @@ impl QStream { Self { _api: api, - inner: SBox::new(s), + inner: Arc::new(SBox::new(s)), ctx, } } @@ -188,7 +204,7 @@ impl QStream { // open client stream pub fn open(connection: &QConnection, flags: StreamOpenFlags) -> Self { let s = Stream::new(&connection._api.inner.inner); - let ctx = Box::new(Mutex::new(QStreamCtx::new())); + let ctx = Arc::new(Mutex::new(QStreamCtx::new())); s.open( &connection.inner.inner, flags, @@ -198,61 +214,116 @@ impl QStream { Self { _api: connection._api.clone(), ctx, - inner: SBox::new(s), + inner: Arc::new(SBox::new(s)), + } + } + + pub fn start_only(&self, flags: StreamStartFlags) { + self.inner.inner.start(flags); + } + + pub fn poll_start( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let p = self.ctx.lock().unwrap().start_sig.poll(cx); + match p { + std::task::Poll::Ready(op) => match op { + Some(_) => Poll::Ready(Ok(())), + None => Poll::Ready(Err(Error::from(ErrorKind::BrokenPipe))), + }, + std::task::Poll::Pending => Poll::Pending, } } // start stream for client pub async fn start(&mut self, flags: StreamStartFlags) -> Result<(), Error> { // regardless of start success of fail, there is a QUIC_STREAM_EVENT_START_COMPLETE callback. - let rx; - { - // prepare the channel. - rx = self.ctx.lock().unwrap().start_sig.reset(); - self.inner.inner.start(flags); - } - // wait for backend - match rx.await { - StartPayload::Success => Ok(()), + self.start_only(flags); + let fu = poll_fn(|cx| self.poll_start(cx)); + fu.await + } + + pub fn poll_receive( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let p = self.ctx.lock().unwrap().receive_ch.poll(cx); + match p { + Poll::Ready(op) => match op { + Some(b) => Poll::Ready(Ok(b.0)), + None => Poll::Ready(Err(Error::from(ErrorKind::BrokenPipe))), + }, + Poll::Pending => Poll::Pending, } } // receive into this buff // return num of bytes wrote. pub async fn receive(&mut self) -> Result { - let rx; - { - rx = self.ctx.lock().unwrap().receive_ch.pop(); - } - - let v = rx - .await - .map_err(|e: u32| Error::from_raw_os_error(e.try_into().unwrap()))?; - Ok(v.0) + let fu = poll_fn(|cx| self.poll_receive(cx)); + fu.await } // fn receive_complete(&self, len: u64) { // // TODO: handle error // let _ = self.inner.inner.receive_complete(len); // } + pub fn send_only(&mut self, buffers: impl Buf + 'static, flags: SendFlags) { + let mut lk = self.ctx.lock().unwrap(); + lk.send_sig.set_frontend_pending(); + let b = QBufWrap::new(Box::new(buffers)); + // hold on the buffer until callback. + let prev = lk.pending_buf.replace(b); + assert!(prev.is_none()); + let bb = lk.pending_buf.as_ref().unwrap().as_buffs(); + self.inner + .inner + .send(&bb[0], bb.len() as u32, flags, std::ptr::null()); + } - pub async fn send(&mut self, buffers: impl Buf, flags: SendFlags) -> Result<(), Error> { - let b = QBufWrap::new(buffers); - let rx; - { - let bb = b.as_buffs(); - rx = self.ctx.lock().unwrap().send_ch.reset(); - self.inner - .inner - .send(&bb[0], bb.len() as u32, flags, std::ptr::null()); + pub fn poll_send(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + let mut p = self.ctx.lock().unwrap(); + Self::poll_send_inner(&mut p, cx) + } + + fn poll_send_inner( + lk: &mut MutexGuard, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = lk.send_sig.poll(cx); + match p { + std::task::Poll::Ready(op) => match op { + Some(e) => match e { + SentPayload::Success => Poll::Ready(Ok(())), + SentPayload::Canceled => { + Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))) + } + }, + None => Poll::Ready(Err(Error::from(ErrorKind::BrokenPipe))), + }, + std::task::Poll::Pending => Poll::Pending, } + } - // wait backend - let res = rx.await; - match res { - SentPayload::Success => Ok(()), - SentPayload::Canceled => Err(Error::from(ErrorKind::ConnectionAborted)), + pub async fn send( + &mut self, + buffers: impl Buf + 'static, + flags: SendFlags, + ) -> Result<(), Error> { + self.send_only(buffers, flags); + let fu = poll_fn(|cx| self.poll_send(cx)); + fu.await + } + + // poll if send is ready for more data + pub fn poll_ready_send(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + let mut lk = self.ctx.lock().unwrap(); + // If frontend pending is not yet cleared from backend, we set waker in backend. + if !lk.send_sig.is_frontend_pending() { + return Poll::Ready(Ok(())); } + Self::poll_send_inner(&mut lk, cx) } // send shutdown signal to peer. @@ -266,6 +337,14 @@ impl QStream { rx.await; } + // this is for h3 where the interface does not wait + // We will ignore callback. + pub fn stop_sending(&self, error_code: u64) { + self.inner + .inner + .shutdown(STREAM_SHUTDOWN_FLAG_GRACEFUL, error_code) + } + // wait for the complete shutdown event. before close handle. pub async fn drain(&mut self) { let rx; diff --git a/crates/libs/msquic/src/sync.rs b/crates/libs/msquic/src/sync.rs index 40ee959..ab431e8 100644 --- a/crates/libs/msquic/src/sync.rs +++ b/crates/libs/msquic/src/sync.rs @@ -2,7 +2,8 @@ use std::{ collections::LinkedList, future::Future, pin::Pin, - task::{Context, Poll}, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, }; use tokio::sync::oneshot::{self, Receiver}; @@ -161,3 +162,248 @@ impl QQueue { } } } + +#[derive(Default)] +struct QWakableQueueState { + data: LinkedList, + waker: Option, + is_closed: bool, +} + +#[derive(Clone)] +pub struct QWakableQueue { + state: Arc>>, +} + +impl Default for QWakableQueue { + fn default() -> Self { + let state = QWakableQueueState { + data: LinkedList::new(), + waker: None, + is_closed: false, + }; + Self { + state: Arc::new(Mutex::new(state)), + } + } +} + +impl QWakableQueue { + // insert the data and wake + pub fn insert(&mut self, data: T) { + let mut lk = self.state.lock().unwrap(); + if lk.is_closed { + panic!("set after close") + } + lk.data.push_back(data); + if lk.waker.is_some() { + lk.waker.take().unwrap().wake(); + } + } + + // if polled none the res is cancelled and will not be delivered. + // only one poll can happen at a time. + pub fn poll(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + let mut lk = self.state.lock().unwrap(); + match lk.data.pop_front() { + Some(d) => Poll::Ready(Some(d)), + None => { + if lk.is_closed { + Poll::Ready(None) + } else { + // register waker + let prev = lk.waker.replace(cx.waker().clone()); + assert!(prev.is_none()); + Poll::Pending + } + } + } + } + + pub fn close(&mut self) { + let mut lk = self.state.lock().unwrap(); + if lk.is_closed { + return; + } + lk.is_closed = true; + // ask for poll + if let Some(w) = lk.waker.take() { + w.wake(); + } + } +} + +struct QWakableSigState { + data: Option, + wakers: LinkedList, // multiple waker can register + is_closed: bool, + frontend_pending: bool, +} + +impl Default for QWakableSigState { + fn default() -> Self { + Self { + data: Default::default(), + wakers: Default::default(), + is_closed: Default::default(), + frontend_pending: Default::default(), + } + } +} + +pub struct QWakableSig { + inner: Arc>>, +} + +impl Default for QWakableSig { + fn default() -> Self { + Self { + inner: Default::default(), + } + } +} + +impl QWakableSig { + // frontend has action pending. For example frontend initiated send + pub fn set_frontend_pending(&mut self) { + let mut lk = self.inner.lock().unwrap(); + assert!(!lk.frontend_pending); + lk.frontend_pending = true; + } + + pub fn is_frontend_pending(&self) -> bool { + let lk = self.inner.lock().unwrap(); + lk.frontend_pending + } + + // if none is returned, means the sig is cannceled. + // Multiple waker can poll. + pub fn poll(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + let mut lk = self.inner.lock().unwrap(); + match &lk.data { + Some(s) => Poll::Ready(Some(s.clone())), + None => { + if lk.is_closed { + Poll::Ready(None) + } else { + // save waker + lk.wakers.push_back(cx.waker().clone()); + Poll::Pending + } + } + } + } + + pub fn set(&mut self, data: T) { + let mut lk = self.inner.lock().unwrap(); + if lk.data.is_some() { + return; // already set + } + if lk.is_closed { + panic!("set after close"); + } + lk.data.replace(data); + lk.frontend_pending = false; // the set corresponds to front end action, and we clear it here. + // wake all wakers + while let Some(w) = lk.wakers.pop_front() { + w.wake(); + } + } + + // reset to default state. + pub fn reset(&mut self) { + let mut lk = self.inner.lock().unwrap(); + if lk.is_closed { + panic!("reset after close"); + } + if !lk.wakers.is_empty() { + panic!("reset while waker is pending"); + } + lk.data = None; + } + + pub fn close(&mut self) { + let mut lk = self.inner.lock().unwrap(); + if lk.is_closed { + return; + } + lk.is_closed = true; + // wake all waker + while let Some(w) = lk.wakers.pop_front() { + w.wake(); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + future::poll_fn, + sync::atomic::AtomicUsize, + task::{Context, Poll}, + time::Duration, + }; + + use crate::sync::QWakableQueue; + + static COUNTER: AtomicUsize = AtomicUsize::new(0); + fn read_line(_cx: &mut Context<'_>) -> Poll { + println!("readline called"); + // the second poll should work + if COUNTER.fetch_add(1, std::sync::atomic::Ordering::Acquire) < 1 { + _cx.waker().clone().wake(); + Poll::Pending + } else { + Poll::Ready("Hello, World!".into()) + } + } + + #[tokio::test] + async fn poll_test() { + let read_future = poll_fn(read_line); + assert_eq!(read_future.await, "Hello, World!".to_owned()); + } + + #[tokio::test] + async fn wake_test() { + // set in same thread + let mut wakable = QWakableQueue::default(); + wakable.insert(String::from("hello")); + let fu = poll_fn(|cx| wakable.poll(cx)); + let out = fu.await.unwrap(); + assert_eq!(out, "hello"); + } + + #[tokio::test] + async fn wake_test2() { + // set from different task + let mut wakable = QWakableQueue::default(); + let mut w_cp = wakable.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + w_cp.insert(String::from("hello")); + w_cp.insert(String::from("hello2")); + }); + let mut wakable_cp = wakable.clone(); + let fu = poll_fn(|cx| wakable_cp.poll(cx)); + let out = fu.await.unwrap(); + assert_eq!(out, "hello"); + let fu2 = poll_fn(|cx| wakable.poll(cx)); + let out2 = fu2.await.unwrap(); + assert_eq!(out2, "hello2"); + } + + #[tokio::test] + async fn wake_test3() { + // close + let mut wakable: QWakableQueue = Default::default(); + let mut w_cp = wakable.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + w_cp.close(); + }); + let fu = poll_fn(|cx| wakable.poll(cx)); + let out = fu.await; + assert!(out.is_none()); + } +}