diff --git a/crates/libs/c2/src/lib.rs b/crates/libs/c2/src/lib.rs index 82ba827..599d8bb 100644 --- a/crates/libs/c2/src/lib.rs +++ b/crates/libs/c2/src/lib.rs @@ -22,6 +22,7 @@ use libc::c_void; use serde::{Deserialize, Serialize}; use std::convert::TryInto; use std::fmt; +use std::mem::size_of_val; use std::option::Option; use std::ptr; #[macro_use] @@ -1739,6 +1740,20 @@ impl Stream { ); assert!(Status::succeeded(status), "Code: 0x{:x}", status); } + + // get stream id + pub fn get_id(&self) -> u64 { + let mut id: u64 = 0; + let size = size_of_val(&id) as u32; + let status = (unsafe { self.table.as_ref().unwrap().get_param })( + self.handle, + PARAM_STREAM_ID, + std::ptr::addr_of!(size) as *mut u32, + std::ptr::addr_of_mut!(id) as *mut c_void, + ); + assert!(Status::succeeded(status), "Code: 0x{:x}", status); + id + } } impl Drop for Stream { diff --git a/crates/libs/msquic/src/conn.rs b/crates/libs/msquic/src/conn.rs index 58d5dfa..9bf9f5f 100644 --- a/crates/libs/msquic/src/conn.rs +++ b/crates/libs/msquic/src/conn.rs @@ -3,9 +3,9 @@ use crate::{ info, reg::QRegistration, stream::QStream, - sync::{QQueue, QReceiver, QResetChannel, QSignal}, + sync::{QQueue, QReceiver, QResetChannel, QWakableSig}, }; -use std::{ffi::c_void, fmt::Debug, io::Error, sync::Mutex}; +use std::{ffi::c_void, fmt::Debug, future::poll_fn, io::Error, sync::Mutex, task::Poll}; use c2::{ Configuration, Connection, ConnectionEvent, Handle, SendResumptionFlags, @@ -51,7 +51,7 @@ enum ConnStatus { struct QConnectionCtx { _api: QApi, strm_ch: QQueue, - shtdwn_sig: QSignal, + shtdwn_sig: QWakableSig<()>, //state: Mutex, conn_ch: QResetChannel, // handle connect success or transport close proceed_rx: Option>, // used for server wait conn @@ -135,7 +135,7 @@ impl QConnectionCtx { Self { _api: api.clone(), strm_ch: QQueue::new(), - shtdwn_sig: QSignal::new(), + shtdwn_sig: QWakableSig::default(), conn_ch: QResetChannel::new(), proceed_rx: None, } @@ -165,9 +165,7 @@ impl QConnectionCtx { } fn on_shutdown_complete(&mut self) { self.strm_ch.close(0); - if self.shtdwn_sig.can_set() { - self.shtdwn_sig.set(()); - } + self.shtdwn_sig.set(()); } fn on_peer_stream_started(&mut self, h: Handle) { let s = QStream::attach(self._api.clone(), h); @@ -269,16 +267,28 @@ impl QConnection { } } - pub async fn shutdown(&mut self) { - let rx; + pub fn shutdown_only(&self, ec: u64) { + self.inner.inner.shutdown(CONNECTION_SHUTDOWN_FLAG_NONE, ec); // ec + } + + pub fn poll_shutdown(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> { + let mut lk = self.ctx.lock().unwrap(); + let p = lk.shtdwn_sig.poll(cx); + match p { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + + pub async fn shutdown(&mut self, ec: u64) { { - rx = self.ctx.lock().unwrap().shtdwn_sig.reset(); + let mut lk = self.ctx.lock().unwrap(); + if !lk.shtdwn_sig.is_frontend_pending() { + lk.shtdwn_sig.set_frontend_pending(); + self.shutdown_only(ec) + } } - info!("conn invoke shutdown"); - // callback maybe sync - self.inner.inner.shutdown(CONNECTION_SHUTDOWN_FLAG_NONE, 0); // ec - info!("conn wait for shutdown evnet"); - rx.await; - info!("conn wait for shutdown evnet end"); + let fu = poll_fn(|cx| self.poll_shutdown(cx)); + fu.await } } diff --git a/crates/libs/msquic/src/lib.rs b/crates/libs/msquic/src/lib.rs index 9f85279..d468109 100644 --- a/crates/libs/msquic/src/lib.rs +++ b/crates/libs/msquic/src/lib.rs @@ -168,7 +168,7 @@ mod tests { rth.spawn(async move { info!("server accepted stream"); let mut s = s.unwrap(); - info!("server stream receive"); + info!("server stream {} receive", s.get_id()); let read = s.receive().await.unwrap(); let payload = debug_buf_to_string(read); info!("server received len {}", payload.len()); @@ -182,7 +182,7 @@ mod tests { }); } info!("server conn shutdown"); - conn.shutdown().await; + conn.shutdown(0).await; info!("server conn shutdown end"); }); } @@ -224,7 +224,7 @@ mod tests { info!("client stream start"); st.start(STREAM_START_FLAG_NONE).await.unwrap(); let args = Bytes::from("hello"); - info!("client stream send"); + info!("client stream {} send", st.get_id()); st.send(args, SEND_FLAG_FIN).await.unwrap(); info!("client stream receive"); @@ -235,7 +235,7 @@ mod tests { info!("client stream drain"); st.drain().await; info!("client conn shutdown"); - conn.shutdown().await; + conn.shutdown(0).await; // shutdown server sht_tx.send(()).unwrap(); }); diff --git a/crates/libs/msquic/src/msh3/mod.rs b/crates/libs/msquic/src/msh3/mod.rs index 58d3ad5..e72a4fb 100644 --- a/crates/libs/msquic/src/msh3/mod.rs +++ b/crates/libs/msquic/src/msh3/mod.rs @@ -1,10 +1,10 @@ // h3 wrappings for msquic -use std::fmt::Display; +use std::{fmt::Display, task::Poll}; use bytes::{Buf, BytesMut}; -use c2::SEND_FLAG_NONE; -use h3::quic::{BidiStream, Connection, OpenStreams, RecvStream, SendStream, StreamId}; +use c2::{SEND_FLAG_NONE, STREAM_OPEN_FLAG_NONE, STREAM_OPEN_FLAG_UNIDIRECTIONAL}; +use h3::quic::{BidiStream, Connection, OpenStreams, RecvStream, SendStream}; use crate::{conn::QConnection, stream::QStream}; @@ -58,18 +58,23 @@ impl OpenStreams for H3Conn { &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - todo!() + let s = QStream::open(&self._inner, STREAM_OPEN_FLAG_NONE); + // TODO: start? + Poll::Ready(Ok(H3Stream::new(s))) } fn poll_open_send( &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - todo!() + let s = QStream::open(&self._inner, STREAM_OPEN_FLAG_UNIDIRECTIONAL); + // TODO: start? + Poll::Ready(Ok(H3Stream::new(s))) } - fn close(&mut self, _code: h3::error::Code, _reason: &[u8]) { - todo!() + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + // TODO? + self._inner.shutdown_only(code.value()) } } @@ -116,22 +121,20 @@ impl Connection for H3Conn { todo!() } - fn close(&mut self, _code: h3::error::Code, _reason: &[u8]) { - todo!() + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self._inner.shutdown_only(code.value()) } } pub struct H3Stream { inner: QStream, - id: h3::quic::StreamId, shutdown: bool, } impl H3Stream { - fn new(s: QStream, id: StreamId) -> Self { + fn new(s: QStream) -> Self { Self { inner: s, - id, shutdown: false, } } @@ -175,7 +178,7 @@ impl SendStream for H3Stream { } fn send_id(&self) -> h3::quic::StreamId { - self.id + self.inner.get_id().try_into().expect("cannot convert id") } } @@ -184,16 +187,15 @@ impl RecvStream for H3Stream { type Error = H3Error; + // currently error is not propagated. fn poll_data( &mut self, - _cx: &mut std::task::Context<'_>, + cx: &mut std::task::Context<'_>, ) -> std::task::Poll, Self::Error>> { - // let fu = self.inner.receive(); - // let innner = as Future>::poll(Pin::new(&mut self.rx), _cx); - //Pin::new(&mut fu).poll(cx); - // let mut pinned_fut = pin!(fu); - // pinned_fut.poll(cx); - todo!() + match self.inner.poll_receive(cx) { + std::task::Poll::Ready(br) => Poll::Ready(Ok(br)), + std::task::Poll::Pending => Poll::Pending, + } } fn stop_sending(&mut self, error_code: u64) { @@ -201,7 +203,8 @@ impl RecvStream for H3Stream { } fn recv_id(&self) -> h3::quic::StreamId { - self.id + let id = self.inner.get_id(); + id.try_into().expect("invalid stream id") } } @@ -212,7 +215,6 @@ impl BidiStream for H3Stream { fn split(self) -> (Self::SendStream, Self::RecvStream) { let cp = self.inner.clone(); - let id = self.id; - (self, H3Stream::new(cp, id)) + (self, H3Stream::new(cp)) } } diff --git a/crates/libs/msquic/src/stream.rs b/crates/libs/msquic/src/stream.rs index d5a3629..679da42 100644 --- a/crates/libs/msquic/src/stream.rs +++ b/crates/libs/msquic/src/stream.rs @@ -15,7 +15,7 @@ use crate::{ utils::SBox, QApi, }; -use bytes::Buf; +use bytes::{Buf, BytesMut}; use c2::{ Buffer, Handle, SendFlags, Stream, StreamEvent, StreamOpenFlags, StreamStartFlags, STREAM_EVENT_PEER_RECEIVE_ABORTED, STREAM_EVENT_PEER_SEND_ABORTED, @@ -242,15 +242,16 @@ impl QStream { fu.await } + // todo: propagate error pub fn poll_receive( &mut self, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> std::task::Poll> { let p = self.ctx.lock().unwrap().receive_ch.poll(cx); match p { Poll::Ready(op) => match op { - Some(b) => Poll::Ready(Ok(b.0)), - None => Poll::Ready(Err(Error::from(ErrorKind::BrokenPipe))), + Some(b) => Poll::Ready(Some(b.0)), + None => Poll::Ready(None), }, Poll::Pending => Poll::Pending, } @@ -258,7 +259,7 @@ impl QStream { // receive into this buff // return num of bytes wrote. - pub async fn receive(&mut self) -> Result { + pub async fn receive(&mut self) -> Option { let fu = poll_fn(|cx| self.poll_receive(cx)); fu.await } @@ -363,4 +364,9 @@ impl QStream { } rx.await; } + + // get stream id + pub fn get_id(&self) -> u64 { + self.inner.inner.get_id() + } }