Skip to content

Commit

Permalink
get stream id from stream handle for h3
Browse files Browse the repository at this point in the history
  • Loading branch information
youyuanwu committed Jun 14, 2024
1 parent d27aa0e commit 02d9f42
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 48 deletions.
15 changes: 15 additions & 0 deletions crates/libs/c2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use libc::c_void;
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
use std::fmt;
use std::mem::size_of_val;
use std::option::Option;
use std::ptr;
#[macro_use]
Expand Down Expand Up @@ -1739,6 +1740,20 @@ impl Stream {
);
assert!(Status::succeeded(status), "Code: 0x{:x}", status);
}

// get stream id
pub fn get_id(&self) -> u64 {
let mut id: u64 = 0;
let size = size_of_val(&id) as u32;
let status = (unsafe { self.table.as_ref().unwrap().get_param })(
self.handle,
PARAM_STREAM_ID,
std::ptr::addr_of!(size) as *mut u32,
std::ptr::addr_of_mut!(id) as *mut c_void,
);
assert!(Status::succeeded(status), "Code: 0x{:x}", status);
id
}
}

impl Drop for Stream {
Expand Down
42 changes: 26 additions & 16 deletions crates/libs/msquic/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use crate::{
info,
reg::QRegistration,
stream::QStream,
sync::{QQueue, QReceiver, QResetChannel, QSignal},
sync::{QQueue, QReceiver, QResetChannel, QWakableSig},
};
use std::{ffi::c_void, fmt::Debug, io::Error, sync::Mutex};
use std::{ffi::c_void, fmt::Debug, future::poll_fn, io::Error, sync::Mutex, task::Poll};

use c2::{
Configuration, Connection, ConnectionEvent, Handle, SendResumptionFlags,
Expand Down Expand Up @@ -51,7 +51,7 @@ enum ConnStatus {
struct QConnectionCtx {
_api: QApi,
strm_ch: QQueue<QStream>,
shtdwn_sig: QSignal,
shtdwn_sig: QWakableSig<()>,
//state: Mutex<State>,
conn_ch: QResetChannel<ConnStatus>, // handle connect success or transport close
proceed_rx: Option<QReceiver<ConnStatus>>, // used for server wait conn
Expand Down Expand Up @@ -135,7 +135,7 @@ impl QConnectionCtx {
Self {
_api: api.clone(),
strm_ch: QQueue::new(),
shtdwn_sig: QSignal::new(),
shtdwn_sig: QWakableSig::default(),
conn_ch: QResetChannel::new(),
proceed_rx: None,
}
Expand Down Expand Up @@ -165,9 +165,7 @@ impl QConnectionCtx {
}
fn on_shutdown_complete(&mut self) {
self.strm_ch.close(0);
if self.shtdwn_sig.can_set() {
self.shtdwn_sig.set(());
}
self.shtdwn_sig.set(());
}
fn on_peer_stream_started(&mut self, h: Handle) {
let s = QStream::attach(self._api.clone(), h);
Expand Down Expand Up @@ -269,16 +267,28 @@ impl QConnection {
}
}

pub async fn shutdown(&mut self) {
let rx;
pub fn shutdown_only(&self, ec: u64) {
self.inner.inner.shutdown(CONNECTION_SHUTDOWN_FLAG_NONE, ec); // ec
}

pub fn poll_shutdown(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> {
let mut lk = self.ctx.lock().unwrap();
let p = lk.shtdwn_sig.poll(cx);
match p {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
}
}

pub async fn shutdown(&mut self, ec: u64) {
{
rx = self.ctx.lock().unwrap().shtdwn_sig.reset();
let mut lk = self.ctx.lock().unwrap();
if !lk.shtdwn_sig.is_frontend_pending() {
lk.shtdwn_sig.set_frontend_pending();
self.shutdown_only(ec)
}
}
info!("conn invoke shutdown");
// callback maybe sync
self.inner.inner.shutdown(CONNECTION_SHUTDOWN_FLAG_NONE, 0); // ec
info!("conn wait for shutdown evnet");
rx.await;
info!("conn wait for shutdown evnet end");
let fu = poll_fn(|cx| self.poll_shutdown(cx));
fu.await
}
}
8 changes: 4 additions & 4 deletions crates/libs/msquic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ mod tests {
rth.spawn(async move {
info!("server accepted stream");
let mut s = s.unwrap();
info!("server stream receive");
info!("server stream {} receive", s.get_id());
let read = s.receive().await.unwrap();
let payload = debug_buf_to_string(read);
info!("server received len {}", payload.len());
Expand All @@ -182,7 +182,7 @@ mod tests {
});
}
info!("server conn shutdown");
conn.shutdown().await;
conn.shutdown(0).await;
info!("server conn shutdown end");
});
}
Expand Down Expand Up @@ -224,7 +224,7 @@ mod tests {
info!("client stream start");
st.start(STREAM_START_FLAG_NONE).await.unwrap();
let args = Bytes::from("hello");
info!("client stream send");
info!("client stream {} send", st.get_id());
st.send(args, SEND_FLAG_FIN).await.unwrap();

info!("client stream receive");
Expand All @@ -235,7 +235,7 @@ mod tests {
info!("client stream drain");
st.drain().await;
info!("client conn shutdown");
conn.shutdown().await;
conn.shutdown(0).await;
// shutdown server
sht_tx.send(()).unwrap();
});
Expand Down
48 changes: 25 additions & 23 deletions crates/libs/msquic/src/msh3/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// h3 wrappings for msquic

use std::fmt::Display;
use std::{fmt::Display, task::Poll};

use bytes::{Buf, BytesMut};
use c2::SEND_FLAG_NONE;
use h3::quic::{BidiStream, Connection, OpenStreams, RecvStream, SendStream, StreamId};
use c2::{SEND_FLAG_NONE, STREAM_OPEN_FLAG_NONE, STREAM_OPEN_FLAG_UNIDIRECTIONAL};
use h3::quic::{BidiStream, Connection, OpenStreams, RecvStream, SendStream};

use crate::{conn::QConnection, stream::QStream};

Expand Down Expand Up @@ -58,18 +58,23 @@ impl<B: Buf> OpenStreams<B> for H3Conn {
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<Self::BidiStream, Self::Error>> {
todo!()
let s = QStream::open(&self._inner, STREAM_OPEN_FLAG_NONE);
// TODO: start?
Poll::Ready(Ok(H3Stream::new(s)))
}

fn poll_open_send(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<Self::SendStream, Self::Error>> {
todo!()
let s = QStream::open(&self._inner, STREAM_OPEN_FLAG_UNIDIRECTIONAL);
// TODO: start?
Poll::Ready(Ok(H3Stream::new(s)))
}

fn close(&mut self, _code: h3::error::Code, _reason: &[u8]) {
todo!()
fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
// TODO?
self._inner.shutdown_only(code.value())
}
}

Expand Down Expand Up @@ -116,22 +121,20 @@ impl<B: Buf> Connection<B> for H3Conn {
todo!()
}

fn close(&mut self, _code: h3::error::Code, _reason: &[u8]) {
todo!()
fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
self._inner.shutdown_only(code.value())
}
}

pub struct H3Stream {
inner: QStream,
id: h3::quic::StreamId,
shutdown: bool,
}

impl H3Stream {
fn new(s: QStream, id: StreamId) -> Self {
fn new(s: QStream) -> Self {
Self {
inner: s,
id,
shutdown: false,
}
}
Expand Down Expand Up @@ -175,7 +178,7 @@ impl<B: Buf> SendStream<B> for H3Stream {
}

fn send_id(&self) -> h3::quic::StreamId {
self.id
self.inner.get_id().try_into().expect("cannot convert id")
}
}

Expand All @@ -184,24 +187,24 @@ impl RecvStream for H3Stream {

type Error = H3Error;

// currently error is not propagated.
fn poll_data(
&mut self,
_cx: &mut std::task::Context<'_>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<Option<Self::Buf>, Self::Error>> {
// let fu = self.inner.receive();
// let innner = <Receiver<T> as Future>::poll(Pin::new(&mut self.rx), _cx);
//Pin::new(&mut fu).poll(cx);
// let mut pinned_fut = pin!(fu);
// pinned_fut.poll(cx);
todo!()
match self.inner.poll_receive(cx) {
std::task::Poll::Ready(br) => Poll::Ready(Ok(br)),
std::task::Poll::Pending => Poll::Pending,
}
}

fn stop_sending(&mut self, error_code: u64) {
self.inner.stop_sending(error_code);
}

fn recv_id(&self) -> h3::quic::StreamId {
self.id
let id = self.inner.get_id();
id.try_into().expect("invalid stream id")
}
}

Expand All @@ -212,7 +215,6 @@ impl<B: Buf> BidiStream<B> for H3Stream {

fn split(self) -> (Self::SendStream, Self::RecvStream) {
let cp = self.inner.clone();
let id = self.id;
(self, H3Stream::new(cp, id))
(self, H3Stream::new(cp))
}
}
16 changes: 11 additions & 5 deletions crates/libs/msquic/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
utils::SBox,
QApi,
};
use bytes::Buf;
use bytes::{Buf, BytesMut};
use c2::{
Buffer, Handle, SendFlags, Stream, StreamEvent, StreamOpenFlags, StreamStartFlags,
STREAM_EVENT_PEER_RECEIVE_ABORTED, STREAM_EVENT_PEER_SEND_ABORTED,
Expand Down Expand Up @@ -242,23 +242,24 @@ impl QStream {
fu.await
}

// todo: propagate error
pub fn poll_receive(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<impl Buf, Error>> {
) -> std::task::Poll<Option<BytesMut>> {
let p = self.ctx.lock().unwrap().receive_ch.poll(cx);
match p {
Poll::Ready(op) => match op {
Some(b) => Poll::Ready(Ok(b.0)),
None => Poll::Ready(Err(Error::from(ErrorKind::BrokenPipe))),
Some(b) => Poll::Ready(Some(b.0)),
None => Poll::Ready(None),
},
Poll::Pending => Poll::Pending,
}
}

// receive into this buff
// return num of bytes wrote.
pub async fn receive(&mut self) -> Result<impl Buf, Error> {
pub async fn receive(&mut self) -> Option<BytesMut> {
let fu = poll_fn(|cx| self.poll_receive(cx));
fu.await
}
Expand Down Expand Up @@ -363,4 +364,9 @@ impl QStream {
}
rx.await;
}

// get stream id
pub fn get_id(&self) -> u64 {
self.inner.inner.get_id()
}
}

0 comments on commit 02d9f42

Please sign in to comment.