From d6026c4d23a86dc545071c64a052613d2e412b49 Mon Sep 17 00:00:00 2001 From: Kristof Mattei <864376+Kristof-Mattei@users.noreply.github.com> Date: Sun, 29 Oct 2023 14:26:02 -0700 Subject: [PATCH] feat: upgrade to hyper-v1, use hyper-utils for now --- Cargo.toml | 11 ++++++++-- examples/client.rs | 19 ++++++++++++----- src/client.rs | 53 ++++++++++++++++++++++++++++++---------------- src/stream.rs | 32 +++++++++++++++++++--------- 4 files changed, 80 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c6ddbd9..d513b5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,10 +16,17 @@ vendored = ["native-tls/vendored"] [dependencies] bytes = "1" native-tls = "0.2.1" -hyper = { version = "0.14.2", default-features = false, features = ["tcp", "client"] } +hyper = { version = "1.0.0-rc.4", default-features = false } +hyper-util = { git = "https://github.com/hyperium/hyper-util", default-features = false, features = [ + "client", +], rev = "ced9f812460420017705fa7cae4dca7be9e23f4a" } tokio = "1" tokio-native-tls = "0.3" +tower-service = "0.3" +http-body-util = "0.1.0-rc.3" [dev-dependencies] tokio = { version = "1.0.0", features = ["io-std", "macros", "io-util"] } -hyper = { version = "0.14.2", default-features = false, features = ["http1"] } +hyper-util = { git = "https://github.com/hyperium/hyper-util", default-features = false, features = [ + "http1", +], rev = "ced9f812460420017705fa7cae4dca7be9e23f4a" } diff --git a/examples/client.rs b/examples/client.rs index 005a333..e74f2fe 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,20 +1,29 @@ -use hyper::{body::HttpBody as _, Client}; +use bytes::Bytes; +use http_body_util::BodyExt; + +use http_body_util::Empty; use hyper_tls::HttpsConnector; +use hyper_util::{client::legacy::Client, rt::TokioExecutor}; use tokio::io::{self, AsyncWriteExt as _}; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let https = HttpsConnector::new(); - let client = Client::builder().build::<_, hyper::Body>(https); + + let client = Client::builder(TokioExecutor::new()).build::<_, Empty>(https); let mut res = client.get("https://hyper.rs".parse()?).await?; println!("Status: {}", res.status()); println!("Headers:\n{:#?}", res.headers()); - while let Some(chunk) = res.body_mut().data().await { - let chunk = chunk?; - io::stdout().write_all(&chunk).await? + while let Some(frame) = res.body_mut().frame().await { + let frame = frame?; + + if let Some(d) = frame.data_ref() { + io::stdout().write_all(d).await?; + } } + Ok(()) } diff --git a/src/client.rs b/src/client.rs index fee1c99..34b578f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,11 +1,14 @@ +use hyper::{ + rt::{Read, Write}, + Uri, +}; +use hyper_util::{client::connect::HttpConnector, rt::TokioIo}; use std::fmt; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; - -use hyper::{client::connect::HttpConnector, service::Service, Uri}; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio_native_tls::TlsConnector; +use tower_service::Service; use crate::stream::MaybeHttpsStream; @@ -20,7 +23,7 @@ pub struct HttpsConnector { } impl HttpsConnector { - /// Construct a new HttpsConnector. + /// Construct a new `HttpsConnector`. /// /// This uses hyper's default `HttpConnector`, and default `TlsConnector`. /// If you wish to use something besides the defaults, use `From::from`. @@ -28,9 +31,9 @@ impl HttpsConnector { /// # Note /// /// By default this connector will use plain HTTP if the URL provided uses - /// the HTTP scheme (eg: http://example.com/). + /// the HTTP scheme (eg: ). /// - /// If you would like to force the use of HTTPS then call https_only(true) + /// If you would like to force the use of HTTPS then call `https_only(true)` /// on the returned connector. /// /// # Panics @@ -39,10 +42,12 @@ impl HttpsConnector { /// /// To handle that error yourself, you can use the `HttpsConnector::from` /// constructor after trying to make a `TlsConnector`. + #[must_use] pub fn new() -> Self { - native_tls::TlsConnector::new() - .map(|tls| HttpsConnector::new_(tls.into())) - .unwrap_or_else(|e| panic!("HttpsConnector::new() failure: {}", e)) + native_tls::TlsConnector::new().map_or_else( + |e| panic!("HttpsConnector::new() failure: {}", e), + |tls| HttpsConnector::new_(tls.into()), + ) } fn new_(tls: TlsConnector) -> Self { @@ -68,15 +73,22 @@ impl HttpsConnector { /// With connector constructor /// + /// # Panics + /// + /// This will panic if the underlying TLS context could not be created. + /// + /// To handle that error yourself, you can use the `HttpsConnector::from` + /// constructor after trying to make a `TlsConnector`. pub fn new_with_connector(http: T) -> Self { - native_tls::TlsConnector::new() - .map(|tls| HttpsConnector::from((http, tls.into()))) - .unwrap_or_else(|e| { + native_tls::TlsConnector::new().map_or_else( + |e| { panic!( "HttpsConnector::new_with_connector() failure: {}", e ) - }) + }, + |tls| HttpsConnector::from((http, tls.into())), + ) } } @@ -95,14 +107,14 @@ impl fmt::Debug for HttpsConnector { f.debug_struct("HttpsConnector") .field("force_https", &self.force_https) .field("http", &self.http) - .finish() + .finish_non_exhaustive() } } impl Service for HttpsConnector where T: Service, - T::Response: AsyncRead + AsyncWrite + Send + Unpin, + T::Response: Read + Write + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { @@ -131,11 +143,16 @@ where .trim_matches(|c| c == '[' || c == ']') .to_owned(); let connecting = self.http.call(dst); - let tls = self.tls.clone(); + + let tls_connector = self.tls.clone(); + let fut = async move { let tcp = connecting.await.map_err(Into::into)?; + let maybe = if is_https { - let tls = tls.connect(&host, tcp).await?; + let stream = TokioIo::new(tcp); + + let tls = TokioIo::new(tls_connector.connect(&host, stream).await?); MaybeHttpsStream::Https(tls) } else { MaybeHttpsStream::Http(tcp) @@ -155,7 +172,7 @@ type BoxedFut = Pin, BoxEr /// A Future representing work to connect to a URL, and a TLS handshake. pub struct HttpsConnecting(BoxedFut); -impl Future for HttpsConnecting { +impl Future for HttpsConnecting { type Output = Result, BoxError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/src/stream.rs b/src/stream.rs index 4875410..fbd5097 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,8 +4,12 @@ use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; -use hyper::client::connect::{Connected, Connection}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use hyper::rt::{Read, ReadBufCursor, Write}; + +use hyper_util::{ + client::connect::{Connected, Connection}, + rt::TokioIo, +}; pub use tokio_native_tls::TlsStream; /// A stream that might be protected with TLS. @@ -13,7 +17,7 @@ pub enum MaybeHttpsStream { /// A stream over plain text. Http(T), /// A stream protected with TLS. - Https(TlsStream), + Https(TokioIo>>), } // ===== impl MaybeHttpsStream ===== @@ -33,18 +37,24 @@ impl From for MaybeHttpsStream { } } -impl From> for MaybeHttpsStream { - fn from(inner: TlsStream) -> Self { +impl From>> for MaybeHttpsStream { + fn from(inner: TlsStream>) -> Self { + MaybeHttpsStream::Https(TokioIo::new(inner)) + } +} + +impl From>>> for MaybeHttpsStream { + fn from(inner: TokioIo>>) -> Self { MaybeHttpsStream::Https(inner) } } -impl AsyncRead for MaybeHttpsStream { +impl Read for MaybeHttpsStream { #[inline] fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf, + buf: ReadBufCursor<'_>, ) -> Poll> { match Pin::get_mut(self) { MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf), @@ -53,7 +63,7 @@ impl AsyncRead for MaybeHttpsStream { } } -impl AsyncWrite for MaybeHttpsStream { +impl Write for MaybeHttpsStream { #[inline] fn poll_write( self: Pin<&mut Self>, @@ -101,11 +111,13 @@ impl AsyncWrite for MaybeHttpsStream { } } -impl Connection for MaybeHttpsStream { +impl Connection for MaybeHttpsStream { fn connected(&self) -> Connected { match self { MaybeHttpsStream::Http(s) => s.connected(), - MaybeHttpsStream::Https(s) => s.get_ref().get_ref().get_ref().connected(), + MaybeHttpsStream::Https(s) => { + s.inner().get_ref().get_ref().get_ref().inner().connected() + } } } }