diff --git a/Cargo.toml b/Cargo.toml index abf9a8249..c8b82ea00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -161,8 +161,8 @@ tokio-socks = { version = "0.5.2", optional = true } hickory-resolver = { version = "0.24", optional = true, features = ["tokio-runtime"] } # HTTP/3 experimental support -h3 = { version = "0.0.6", optional = true } -h3-quinn = { version = "0.0.7", optional = true } +h3 = { version = "0.0.6", git = "https://github.com/hyperium/h3.git", branch = "master", optional = true } +h3-quinn = { version = "0.0.7", git = "https://github.com/hyperium/h3.git", branch = "master", optional = true } quinn = { version = "0.11.1", default-features = false, features = ["rustls", "runtime-tokio"], optional = true } slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn futures-channel = { version = "0.3", optional = true } @@ -255,6 +255,11 @@ path = "examples/form.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "h3_simple" +path = "examples/h3_simple.rs" +required-features = ["http3", "rustls-tls"] + [[example]] name = "connect_via_lower_priority_tokio_runtime" path = "examples/connect_via_lower_priority_tokio_runtime.rs" diff --git a/examples/h3_simple.rs b/examples/h3_simple.rs index dcca7a2fa..53a2379e5 100644 --- a/examples/h3_simple.rs +++ b/examples/h3_simple.rs @@ -7,18 +7,7 @@ #[cfg(not(target_arch = "wasm32"))] #[tokio::main] async fn main() -> Result<(), reqwest::Error> { - use http::Version; - use reqwest::{Client, IntoUrl, Response}; - - async fn get(url: T) -> reqwest::Result { - Client::builder() - .http3_prior_knowledge() - .build()? - .get(url) - .version(Version::HTTP_3) - .send() - .await - } + let client = reqwest::Client::builder().http3_prior_knowledge().build()?; // Some simple CLI args requirements... let url = match std::env::args().nth(1) { @@ -31,7 +20,11 @@ async fn main() -> Result<(), reqwest::Error> { eprintln!("Fetching {url:?}..."); - let res = get(url).await?; + let res = client + .get(url) + .version(http::Version::HTTP_3) + .send() + .await?; eprintln!("Response: {:?} {}", res.version(), res.status()); eprintln!("Headers: {:#?}\n", res.headers()); diff --git a/src/async_impl/h3_client/pool.rs b/src/async_impl/h3_client/pool.rs index 100a0935d..9529f16ff 100644 --- a/src/async_impl/h3_client/pool.rs +++ b/src/async_impl/h3_client/pool.rs @@ -1,7 +1,9 @@ use bytes::Bytes; use std::collections::{HashMap, HashSet}; +use std::pin::Pin; use std::sync::mpsc::{Receiver, TryRecvError}; use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; use std::time::Duration; use tokio::time::Instant; @@ -126,7 +128,6 @@ impl PoolClient { &mut self, req: Request, ) -> Result, BoxError> { - use http_body_util::{BodyExt, Full}; use hyper::body::Body as _; let (head, req_body) = req.into_parts(); @@ -152,14 +153,7 @@ impl PoolClient { let resp = stream.recv_response().await?; - let mut resp_body = Vec::new(); - while let Some(chunk) = stream.recv_data().await? { - resp_body.extend(chunk.chunk()) - } - - let resp_body = Full::new(resp_body.into()) - .map_err(|never| match never {}) - .boxed(); + let resp_body = crate::async_impl::body::boxed(Incoming::new(stream, resp.headers())); Ok(resp.map(|_| resp_body)) } @@ -195,6 +189,52 @@ impl PoolConnection { } } +struct Incoming { + inner: h3::client::RequestStream, + content_length: Option, +} + +impl Incoming { + fn new(stream: h3::client::RequestStream, headers: &http::header::HeaderMap) -> Self { + Self { + inner: stream, + content_length: headers + .get(http::header::CONTENT_LENGTH) + .and_then(|h| h.to_str().ok()) + .and_then(|v| v.parse().ok()), + } + } +} + +impl http_body::Body for Incoming +where + S: h3::quic::RecvStream, +{ + type Data = Bytes; + type Error = crate::error::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { + match futures_core::ready!(self.inner.poll_recv_data(cx)) { + Ok(Some(mut b)) => Poll::Ready(Some(Ok(hyper::body::Frame::data( + b.copy_to_bytes(b.remaining()), + )))), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(crate::error::body(e)))), + } + } + + fn size_hint(&self) -> hyper::body::SizeHint { + if let Some(content_length) = self.content_length { + hyper::body::SizeHint::with_exact(content_length) + } else { + hyper::body::SizeHint::default() + } + } +} + pub(crate) fn extract_domain(uri: &mut Uri) -> Result { let uri_clone = uri.clone(); match (uri_clone.scheme(), uri_clone.authority()) {