Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rustls 0.23.1, next version of tokio-rustls #112

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ readme = "README.md"
repository = "https://github.com/programatik29/axum-server"
version = "0.6.0"

[patch.crates-io]
tokio-rustls = { git = "https://github.com/rustls/tokio-rustls", rev = "3a153acec6c4d189eb5de501b2155b4484b8651b" } # main

[features]
default = []
tls-rustls = ["arc-swap", "rustls", "rustls-pemfile", "tokio/fs", "tokio/time", "tokio-rustls"]
tls-rustls = ["arc-swap", "rustls", "rustls-pemfile", "tokio/fs", "tokio/time", "dep:tokio-rustls"]
tls-openssl = ["arc-swap", "openssl", "tokio-openssl"]

[dependencies]
Expand All @@ -34,17 +37,17 @@ tower = { version = "0.4", features = ["util"] }
# optional dependencies
## rustls
arc-swap = { version = "1", optional = true }
rustls = { version = "0.21", features = ["dangerous_configuration"], optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
tokio-rustls = { version = "0.24", optional = true }
rustls = { version = "0.23.1", default-features = false, optional = true }
rustls-pemfile = { version = "2.1.0", default-features = false, optional = true }
tokio-rustls = { version = "0.25", default-features = false, features = ["ring"], optional = true }

## openssl
openssl = { version = "0.10", optional = true }
tokio-openssl = { version = "0.6", optional = true }

[dev-dependencies]
serial_test = "2.0"
axum = "0.7"
axum = "0.7.1"
hyper = { version = "1.0.1", features = ["full"] }
tokio = { version = "1", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
Expand Down
125 changes: 63 additions & 62 deletions src/tls_rustls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ use crate::{
server::{io_other, Server},
};
use arc_swap::ArcSwap;
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::Item;
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
ServerConfig,
};
use std::time::Duration;
use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
Expand Down Expand Up @@ -172,10 +174,8 @@ impl RustlsConfig {
/// The certificate must be DER-encoded X.509.
///
/// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
let server_config = spawn_blocking(|| config_from_der(cert, key))
.await
.unwrap()?;
pub async fn from_der(cert: Vec<Vec<u8>>, key: PrivateKeyDer<'static>) -> io::Result<Self> {
let server_config = config_from_der(cert, key)?;
let inner = Arc::new(ArcSwap::from_pointee(server_config));

Ok(Self { inner })
Expand Down Expand Up @@ -218,10 +218,12 @@ impl RustlsConfig {
/// The certificate must be DER-encoded X.509.
///
/// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
let server_config = spawn_blocking(|| config_from_der(cert, key))
.await
.unwrap()?;
pub async fn reload_from_der(
&self,
cert: Vec<Vec<u8>>,
key: PrivateKeyDer<'static>,
) -> io::Result<()> {
let server_config = config_from_der(cert, key)?;
let inner = Arc::new(server_config);

self.inner.store(inner);
Expand Down Expand Up @@ -278,12 +280,10 @@ impl fmt::Debug for RustlsConfig {
}
}

fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
let cert = cert.into_iter().map(Certificate).collect();
let key = PrivateKey(key);
fn config_from_der(cert: Vec<Vec<u8>>, key: PrivateKeyDer<'static>) -> io::Result<ServerConfig> {
let cert = cert.into_iter().map(CertificateDer::from).collect();

let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(io_other)?;
Expand All @@ -295,24 +295,13 @@ fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig>

fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
let cert = rustls_pemfile::certs(&mut cert.as_ref())
.map(|it| it.map(|it| it.to_vec()))
.map(|cert| cert.map(|cert| cert.as_ref().to_vec()))
.collect::<Result<Vec<_>, _>>()?;
// Check the entire PEM file for the key in case it is not first section
let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(&mut key.as_ref())
.filter_map(|i| match i.ok()? {
Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec().into()),
Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec().into()),
_ => None,
})
.collect();

// Make sure file contains only one key
if key_vec.len() != 1 {
return Err(io_other("private key format not supported"));
}
// Use the first private key found.
let key = rustls_pemfile::private_key(&mut key.as_ref())?
.ok_or(io_other("private key format not found"))?;

config_from_der(cert, key_vec.pop().unwrap())
config_from_der(cert, key)
}

async fn config_from_pem_file(
Expand All @@ -330,21 +319,12 @@ async fn config_from_pem_chain_file(
chain: impl AsRef<Path>,
) -> io::Result<ServerConfig> {
let cert = tokio::fs::read(cert.as_ref()).await?;
let cert = rustls_pemfile::certs(&mut cert.as_ref())
.map(|it| it.map(|it| rustls::Certificate(it.to_vec())))
.collect::<Result<Vec<_>, _>>()?;
let cert = rustls_pemfile::certs(&mut cert.as_ref()).collect::<Result<Vec<_>, _>>()?;
let key = tokio::fs::read(chain.as_ref()).await?;
let key_cert: rustls::PrivateKey = match rustls_pemfile::read_one(&mut key.as_ref())?
.ok_or_else(|| io_other("could not parse pem file"))?
{
Item::Pkcs8Key(key) => Ok(rustls::PrivateKey(key.secret_pkcs8_der().to_vec().into())),
x => Err(io_other(format!(
"invalid certificate format, received: {x:?}"
))),
}?;
let key_cert = rustls_pemfile::private_key(&mut key.as_ref())?
.ok_or_else(|| io_other("could not parse pem file"))?;

ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key_cert)
.map_err(|_| io_other("invalid certificate"))
Expand All @@ -362,17 +342,10 @@ mod tests {
use http_body_util::BodyExt;
use hyper::client::conn::http1::{handshake, SendRequest};
use hyper_util::rt::TokioIo;
use rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate, ClientConfig, ServerName,
};
use std::{
convert::TryFrom,
io,
net::SocketAddr,
sync::Arc,
time::{Duration, SystemTime},
};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName};
use rustls::{ClientConfig, SignatureScheme};
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use tokio::time::sleep;
use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
use tokio_rustls::TlsConnector;
Expand Down Expand Up @@ -552,13 +525,15 @@ mod tests {
(handle, server_task, addr)
}

async fn get_first_cert(addr: SocketAddr) -> Certificate {
async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
let stream = TcpStream::connect(addr).await.unwrap();
let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

let (_io, client_connection) = tls_stream.into_inner();

client_connection.peer_certificates().unwrap()[0].clone()
client_connection.peer_certificates().unwrap()[0]
.clone()
.into_owned()
}

async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
Expand Down Expand Up @@ -586,24 +561,50 @@ mod tests {
}

fn tls_connector() -> TlsConnector {
#[derive(Debug)]
struct NoVerify;

impl ServerCertVerifier for NoVerify {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: SystemTime,
_now: rustls::pki_types::UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
]
}
}

let mut client_config = ClientConfig::builder()
.with_safe_defaults()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerify))
.with_no_client_auth();

Expand All @@ -612,7 +613,7 @@ mod tests {
TlsConnector::from(Arc::new(client_config))
}

fn dns_name() -> ServerName {
fn dns_name() -> ServerName<'static> {
ServerName::try_from("localhost").unwrap()
}
}