From ae3fdc6a3f5f051e7a3a16861a2f7749389b689c Mon Sep 17 00:00:00 2001 From: Youyuan Wu Date: Sat, 1 Jun 2024 11:11:41 -0700 Subject: [PATCH] get stream id param --- crates/libs/c2/src/lib.rs | 16 ++++++++++++++++ crates/libs/msquic/src/lib.rs | 4 ++-- crates/libs/msquic/src/msh3/mod.rs | 27 +++++++++++++++------------ crates/libs/msquic/src/stream.rs | 5 +++++ 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/crates/libs/c2/src/lib.rs b/crates/libs/c2/src/lib.rs index 82ba827..73b84e3 100644 --- a/crates/libs/c2/src/lib.rs +++ b/crates/libs/c2/src/lib.rs @@ -22,6 +22,8 @@ use libc::c_void; use serde::{Deserialize, Serialize}; use std::convert::TryInto; use std::fmt; +use std::mem::size_of; +use std::mem::size_of_val; use std::option::Option; use std::ptr; #[macro_use] @@ -1739,6 +1741,20 @@ impl Stream { ); assert!(Status::succeeded(status), "Code: 0x{:x}", status); } + + // get stream id + pub fn get_id(&self) -> u64 { + let mut id: u64 = 0; + let size = size_of_val(&id) as u32; + let status = (unsafe { self.table.as_ref().unwrap().get_param })( + self.handle, + PARAM_STREAM_ID, + std::ptr::addr_of!(size) as *mut u32, + std::ptr::addr_of_mut!(id) as *mut c_void, + ); + assert!(Status::succeeded(status), "Code: 0x{:x}", status); + id + } } impl Drop for Stream { diff --git a/crates/libs/msquic/src/lib.rs b/crates/libs/msquic/src/lib.rs index 391b2da..d468109 100644 --- a/crates/libs/msquic/src/lib.rs +++ b/crates/libs/msquic/src/lib.rs @@ -168,7 +168,7 @@ mod tests { rth.spawn(async move { info!("server accepted stream"); let mut s = s.unwrap(); - info!("server stream receive"); + info!("server stream {} receive", s.get_id()); let read = s.receive().await.unwrap(); let payload = debug_buf_to_string(read); info!("server received len {}", payload.len()); @@ -224,7 +224,7 @@ mod tests { info!("client stream start"); st.start(STREAM_START_FLAG_NONE).await.unwrap(); let args = Bytes::from("hello"); - info!("client stream send"); + info!("client stream {} send", st.get_id()); st.send(args, SEND_FLAG_FIN).await.unwrap(); info!("client stream receive"); diff --git a/crates/libs/msquic/src/msh3/mod.rs b/crates/libs/msquic/src/msh3/mod.rs index 1ac65b2..26edfff 100644 --- a/crates/libs/msquic/src/msh3/mod.rs +++ b/crates/libs/msquic/src/msh3/mod.rs @@ -3,7 +3,7 @@ use std::{fmt::Display, task::Poll}; use bytes::{Buf, BytesMut}; -use c2::SEND_FLAG_NONE; +use c2::{SEND_FLAG_NONE, STREAM_OPEN_FLAG_NONE, STREAM_OPEN_FLAG_UNIDIRECTIONAL}; use h3::quic::{BidiStream, Connection, OpenStreams, RecvStream, SendStream, StreamId}; use crate::{conn::QConnection, stream::QStream}; @@ -58,18 +58,23 @@ impl OpenStreams for H3Conn { &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - todo!() + let s = QStream::open(&self._inner, STREAM_OPEN_FLAG_NONE); + // TODO: start? + Poll::Ready(Ok(H3Stream::new(s))) } fn poll_open_send( &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - todo!() + let s = QStream::open(&self._inner, STREAM_OPEN_FLAG_UNIDIRECTIONAL); + // TODO: start? + Poll::Ready(Ok(H3Stream::new(s))) } - fn close(&mut self, _code: h3::error::Code, _reason: &[u8]) { - todo!() + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + // TODO? + self._inner.shutdown_only(code.value()) } } @@ -123,15 +128,13 @@ impl Connection for H3Conn { pub struct H3Stream { inner: QStream, - id: h3::quic::StreamId, shutdown: bool, } impl H3Stream { - fn new(s: QStream, id: StreamId) -> Self { + fn new(s: QStream) -> Self { Self { inner: s, - id, shutdown: false, } } @@ -175,7 +178,7 @@ impl SendStream for H3Stream { } fn send_id(&self) -> h3::quic::StreamId { - self.id + self.inner.get_id().try_into().expect("cannot convert id") } } @@ -200,7 +203,8 @@ impl RecvStream for H3Stream { } fn recv_id(&self) -> h3::quic::StreamId { - self.id + let id = self.inner.get_id(); + id.try_into().expect("invalid stream id") } } @@ -211,7 +215,6 @@ impl BidiStream for H3Stream { fn split(self) -> (Self::SendStream, Self::RecvStream) { let cp = self.inner.clone(); - let id = self.id; - (self, H3Stream::new(cp, id)) + (self, H3Stream::new(cp)) } } diff --git a/crates/libs/msquic/src/stream.rs b/crates/libs/msquic/src/stream.rs index 5293369..679da42 100644 --- a/crates/libs/msquic/src/stream.rs +++ b/crates/libs/msquic/src/stream.rs @@ -364,4 +364,9 @@ impl QStream { } rx.await; } + + // get stream id + pub fn get_id(&self) -> u64 { + self.inner.inner.get_id() + } }