Skip to content

Commit

Permalink
feat: bump rustls to 0.22 for upgrading to hyper-v1 (#2091)
Browse files Browse the repository at this point in the history
  • Loading branch information
loispostula authored Jan 15, 2024
1 parent 7f4bfcb commit e1589ee
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 115 deletions.
13 changes: 7 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ __tls = []

# Enables common rustls code.
# Equivalent to rustls-tls-manual-roots but shorter :)
__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls", "rustls-pemfile"]
__rustls = ["hyper-rustls", "tokio-rustls", "rustls", "__tls", "rustls-pemfile", "rustls-pki-types"]

# When enabled, disable using the cached SYS_PROXIES.
__internal_proxy_sys_no_cache = []
Expand Down Expand Up @@ -120,10 +120,11 @@ native-tls-crate = { version = "0.2.10", optional = true, package = "native-tls"
tokio-native-tls = { version = "0.3.0", optional = true }

# rustls-tls
hyper-rustls = { version = "0.24.0", default-features = false, optional = true }
rustls = { version = "0.21.6", features = ["dangerous_configuration"], optional = true }
tokio-rustls = { version = "0.24", optional = true }
webpki-roots = { version = "0.25", optional = true }
hyper-rustls = { version = "0.26.0", default-features = false, optional = true }
rustls = { version = "0.22.2", optional = true }
rustls-pki-types = { version = "1.1.0", features = ["alloc"] ,optional = true }
tokio-rustls = { version = "0.25", optional = true }
webpki-roots = { version = "0.26.0", optional = true }
rustls-native-certs = { version = "0.6", optional = true }
rustls-pemfile = { version = "1.0", optional = true }

Expand All @@ -150,7 +151,7 @@ futures-channel = { version="0.3", optional = true}

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
env_logger = "0.8"
hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "http1", "http2", "client", "server", "runtime"] }
hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] }
serde = { version = "1.0", features = ["derive"] }
libflate = "1.0"
brotli_crate = { package = "brotli", version = "3.3.0" }
Expand Down
28 changes: 5 additions & 23 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ use http::header::{
};
use http::uri::Scheme;
use http::Uri;
use hyper_util::client::legacy::{
connect::HttpConnector, /*, ResponseFuture as HyperResponseFuture*/
};
use hyper_util::client::legacy::connect::HttpConnector;
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::TlsConnector;
use pin_project_lite::pin_project;
Expand Down Expand Up @@ -468,18 +466,7 @@ impl ClientBuilder {

#[cfg(feature = "rustls-tls-webpki-roots")]
if config.tls_built_in_root_certs {
use rustls::OwnedTrustAnchor;

let trust_anchors =
webpki_roots::TLS_SERVER_ROOTS.iter().map(|trust_anchor| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
trust_anchor.subject,
trust_anchor.spki,
trust_anchor.name_constraints,
)
});

root_cert_store.add_trust_anchors(trust_anchors);
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}

#[cfg(feature = "rustls-tls-native-roots")]
Expand All @@ -489,11 +476,10 @@ impl ClientBuilder {
for cert in rustls_native_certs::load_native_certs()
.map_err(crate::error::builder)?
{
let cert = rustls::Certificate(cert.0);
// Continue on parsing errors, as native stores often include ancient or syntactically
// invalid certificates, like root certificates without any X509 extensions.
// Inspiration: https://github.com/rustls/rustls/blob/633bf4ba9d9521a95f68766d04c22e2b01e68318/rustls/src/anchors.rs#L105-L112
match root_cert_store.add(&cert) {
match root_cert_store.add(cert.into()) {
Ok(_) => valid_count += 1,
Err(err) => {
invalid_count += 1;
Expand Down Expand Up @@ -536,12 +522,8 @@ impl ClientBuilder {
}

// Build TLS config
let config_builder = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&versions)
.map_err(crate::error::builder)?
.with_root_certificates(root_cert_store);
let config_builder =
rustls::ClientConfig::builder().with_root_certificates(root_cert_store);

// Finalize TLS config
let mut tls = if let Some(id) = config.identity {
Expand Down
2 changes: 1 addition & 1 deletion src/async_impl/h3_client/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use bytes::Bytes;
use h3::client::SendRequest;
use h3_quinn::{Connection, OpenStreams};
use http::Uri;
use hyper::client::connect::dns::Name;
use hyper_util::client::legacy::connect::dns::Name;
use quinn::{ClientConfig, Endpoint, TransportConfig};
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
Expand Down
2 changes: 1 addition & 1 deletion src/async_impl/h3_client/dns.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::task;
use hyper::client::connect::dns::Name;
use hyper_util::client::legacy::connect::dns::Name;
use std::future::Future;
use std::net::SocketAddr;
use std::task::Poll;
Expand Down
2 changes: 1 addition & 1 deletion src/async_impl/h3_client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use h3::client::SendRequest;
use h3_quinn::{Connection, OpenStreams};
use http::uri::{Authority, Scheme};
use http::{Request, Response, Uri};
use hyper::Body as HyperBody;
use hyper::body as HyperBody;
use log::trace;

pub(super) type Key = (Scheme, Authority);
Expand Down
140 changes: 79 additions & 61 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use hyper_util::rt::TokioIo;
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tower_service::Service;

use tokio::io::{AsyncReadExt, AsyncWriteExt};

use pin_project_lite::pin_project;
use std::future::Future;
use std::io::{self, IoSlice};
Expand Down Expand Up @@ -212,8 +214,9 @@ impl Connector {
let tls = tls_proxy.clone();
let host = dst.host().ok_or("no host in url")?.to_string();
let conn = socks::connect(proxy, dst, dns).await?;
let server_name = rustls::ServerName::try_from(host.as_str())
.map_err(|_| "Invalid Server Name")?;
let server_name =
rustls_pki_types::ServerName::try_from(host.as_str().to_owned())
.map_err(|_| "Invalid Server Name")?;
let io = RustlsConnector::from(tls)
.connect(server_name, conn)
.await?;
Expand Down Expand Up @@ -301,8 +304,8 @@ impl Connector {

if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
if !self.nodelay {
let (io, _) = stream.get_ref();
io.set_nodelay(false)?;
let (io, _) = stream.inner().get_ref();
io.inner().inner().set_nodelay(false)?;
}
Ok(Conn {
inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
Expand Down Expand Up @@ -376,7 +379,7 @@ impl Connector {
tls_proxy,
} => {
if dst.scheme() == Some(&Scheme::HTTPS) {
use rustls::ServerName;
use rustls_pki_types::ServerName;
use std::convert::TryFrom;
use tokio_rustls::TlsConnector as RustlsConnector;

Expand All @@ -387,16 +390,18 @@ impl Connector {
let tls = tls.clone();
let conn = http.call(proxy_dst).await?;
log::trace!("tunneling HTTPS over proxy");
let maybe_server_name =
ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name");
let maybe_server_name = ServerName::try_from(host.as_str().to_owned())
.map_err(|_| "Invalid Server Name");
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
let server_name = maybe_server_name?;
let io = RustlsConnector::from(tls)
.connect(server_name, tunneled)
.connect(server_name, TokioIo::new(tunneled))
.await?;

return Ok(Conn {
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
inner: self.verbose.wrap(RustlsTlsConn {
inner: TokioIo::new(io),
}),
is_proxy: false,
tls_info: false,
});
Expand Down Expand Up @@ -533,54 +538,47 @@ impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStrea
}
}

/*
#[cfg(feature = "__tls")]
impl<T: TlsInfoFactory> TlsInfoFactory for TokioIo<T> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner().tls_info()
}
}
#[cfg(feature = "default-tls")]
impl<T: TlsInfoFactory> TlsInfoFactory for hyper_tls::MaybeHttpsStream<T> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
match self {
hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
hyper_tls::MaybeHttpsStream::Http(_) => None,
}
}
}
#[cfg(feature = "default-tls")]
impl<T> TlsInfoFactory for hyper_tls::TlsStream<T>
where
T: tokio::io::AsyncRead + tokio::io::AsyncWrite,
{
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
.peer_certificate()
.ok()
.flatten()
.and_then(|c| c.to_der().ok());
.1
.peer_certificates()
.and_then(|certs| certs.first())
.map(|c| c.first())
.and_then(|c| c.map(|cc| vec![*cc]));
Some(crate::tls::TlsInfo { peer_certificate })
}
}

#[cfg(feature = "__rustls")]
impl<T> TlsInfoFactory for tokio_rustls::TlsStream<T> {
impl TlsInfoFactory
for tokio_rustls::client::TlsStream<
TokioIo<hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
>
{
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
let peer_certificate = self
.get_ref()
.1
.peer_certificates()
.and_then(|certs| certs.first())
.map(|c| c.0.clone());
.map(|c| c.first())
.and_then(|c| c.map(|cc| vec![*cc]));
Some(crate::tls::TlsInfo { peer_certificate })
}
}
*/

#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
match self {
hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
hyper_rustls::MaybeHttpsStream::Http(_) => None,
}
}
}

pub(crate) trait AsyncConn:
Read + Write + Connection + Send + Sync + Unpin + 'static
Expand Down Expand Up @@ -692,7 +690,6 @@ where
T: Read + Write + Unpin,
{
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

let mut buf = format!(
"\
Expand Down Expand Up @@ -793,7 +790,7 @@ mod native_tls_conn {
}
}

impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>> {
impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
fn connected(&self) -> Connected {
self.inner
.inner()
Expand Down Expand Up @@ -870,32 +867,57 @@ mod native_tls_conn {
mod rustls_tls_conn {
use super::TlsInfoFactory;
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_rustls::MaybeHttpsStream;
use hyper_util::client::legacy::connect::{Connected, Connection};
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::{
io::{self, IoSlice},
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;

pin_project! {
pub(super) struct RustlsTlsConn<T> {
#[pin] pub(super) inner: TlsStream<T>,
#[pin] pub(super) inner: TokioIo<TlsStream<T>>,
}
}

impl<T: Connection + Read + Write + Unpin> Connection for RustlsTlsConn<T> {
impl Connection for RustlsTlsConn<TokioIo<TokioIo<TcpStream>>> {
fn connected(&self) -> Connected {
if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
self.inner
.inner()
.get_ref()
.0
.inner()
.connected()
.negotiated_h2()
} else {
self.inner.inner().get_ref().0.inner().connected()
}
}
}
impl Connection for RustlsTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
fn connected(&self) -> Connected {
if self.inner.get_ref().1.alpn_protocol() == Some(b"h2") {
self.inner.get_ref().0.connected().negotiated_h2()
if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
self.inner
.inner()
.get_ref()
.0
.inner()
.connected()
.negotiated_h2()
} else {
self.inner.get_ref().0.connected()
self.inner.inner().get_ref().0.inner().connected()
}
}
}

impl<T: Read + Write + Unpin> Read for RustlsTlsConn<T> {
impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustlsTlsConn<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
Expand All @@ -906,14 +928,14 @@ mod rustls_tls_conn {
}
}

impl<T: Read + Write + Unpin> Write for RustlsTlsConn<T> {
impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustlsTlsConn<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_write(this.inner, cx, buf)
Write::poll_write(this.inner, cx, buf)
}

fn poll_write_vectored(
Expand All @@ -922,7 +944,7 @@ mod rustls_tls_conn {
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
Write::poll_write_vectored(this.inner, cx, bufs)
}

fn is_write_vectored(&self) -> bool {
Expand All @@ -934,25 +956,21 @@ mod rustls_tls_conn {
cx: &mut Context,
) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
Write::poll_flush(this.inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
}

impl TlsInfoFactory for RustlsTlsConn<tokio::net::TcpStream> {
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner.tls_info()
Write::poll_shutdown(this.inner, cx)
}
}

impl TlsInfoFactory for RustlsTlsConn<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>> {
impl<T> TlsInfoFactory for RustlsTlsConn<T>
where
TokioIo<TlsStream<T>>: TlsInfoFactory,
{
fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
self.inner.tls_info()
}
Expand Down
Loading

0 comments on commit e1589ee

Please sign in to comment.