From 241401037eede61bec77a0f1ace585c0b23161c4 Mon Sep 17 00:00:00 2001 From: glendc Date: Thu, 19 Sep 2024 12:54:39 +0200 Subject: [PATCH] allow acceptor/connector data to be injected via ext --- rama-tls/src/boring/client/http.rs | 15 +++++++++----- rama-tls/src/boring/server/service.rs | 23 ++++++++++++--------- rama-tls/src/rustls/client/http.rs | 16 +++++++++------ rama-tls/src/rustls/server/service.rs | 29 ++++++++++++++++++--------- 4 files changed, 52 insertions(+), 31 deletions(-) diff --git a/rama-tls/src/boring/client/http.rs b/rama-tls/src/boring/client/http.rs index 8213fae8..39be3d12 100644 --- a/rama-tls/src/boring/client/http.rs +++ b/rama-tls/src/boring/client/http.rs @@ -233,7 +233,8 @@ where .map_err(|err| { OpaqueError::from_boxed(err.into()) .context("HttpsConnector(auto): compute transport context") - })?; + })? + .clone(); if !transport_ctx .app_protocol @@ -257,7 +258,8 @@ where let host = transport_ctx.authority.host().clone(); - let (stream, negotiated_params) = self.handshake(host, conn).await?; + let connector_data = ctx.get().cloned(); + let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?; tracing::trace!( authority = %transport_ctx.authority, @@ -313,7 +315,8 @@ where let host = transport_ctx.authority.host().clone(); - let (conn, negotiated_params) = self.handshake(host, conn).await?; + let connector_data = ctx.get().cloned(); + let (conn, negotiated_params) = self.handshake(connector_data, host, conn).await?; ctx.insert(negotiated_params); Ok(EstablishedClientConnection { @@ -363,7 +366,8 @@ where } }; - let (stream, negotiated_params) = self.handshake(host, conn).await?; + let connector_data = ctx.get().cloned(); + let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?; ctx.insert(negotiated_params); tracing::trace!("HttpsConnector(tunnel): connection secured"); @@ -381,13 +385,14 @@ where impl HttpsConnector { async fn handshake( &self, + connector_data: Option, server_host: Host, stream: T, ) -> Result<(SslStream, NegotiatedTlsParameters), BoxError> where T: Stream + Unpin, { - let (config, server_host) = match &self.connector_data { + let (config, server_host) = match connector_data.as_ref().or(self.connector_data.as_ref()) { Some(connector_data) => { let client_config = connector_data.connect_config_input.try_to_build_config()?; let server_host = connector_data.server_name().cloned().unwrap_or(server_host); diff --git a/rama-tls/src/boring/server/service.rs b/rama-tls/src/boring/server/service.rs index ae72dd27..59014c40 100644 --- a/rama-tls/src/boring/server/service.rs +++ b/rama-tls/src/boring/server/service.rs @@ -73,6 +73,11 @@ where type Error = BoxError; async fn serve(&self, mut ctx: Context, stream: IO) -> Result { + // allow tls acceptor data to be injected, + // e.g. useful for TLS environments where some data (such as server auth, think ACME) + // is updated at runtime, be it infrequent + let tls_config = &ctx.get::().unwrap_or(&self.data).config; + let mut acceptor_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server()) .context("create boring ssl acceptor")?; @@ -81,7 +86,7 @@ where .set_default_verify_paths() .context("build boring ssl acceptor: set default verify paths")?; - for (i, ca_cert) in self.data.config.cert_chain.iter().enumerate() { + for (i, ca_cert) in tls_config.cert_chain.iter().enumerate() { if i == 0 { acceptor_builder .set_certificate(ca_cert.as_ref()) @@ -93,13 +98,13 @@ where } } acceptor_builder - .set_private_key(self.data.config.private_key.as_ref()) + .set_private_key(tls_config.private_key.as_ref()) .context("build boring ssl acceptor: set private key")?; acceptor_builder .check_private_key() .context("build boring ssl acceptor: check private key")?; - if let Some(min_ver) = self.data.config.protocol_versions.iter().flatten().min() { + if let Some(min_ver) = tls_config.protocol_versions.iter().flatten().min() { acceptor_builder .set_min_proto_version(Some((*min_ver).try_into().map_err(|v| { OpaqueError::from_display(format!("protocol version {v}")) @@ -108,7 +113,7 @@ where .context("build boring ssl acceptor: set min proto version")?; } - if let Some(max_ver) = self.data.config.protocol_versions.iter().flatten().max() { + if let Some(max_ver) = tls_config.protocol_versions.iter().flatten().max() { acceptor_builder .set_max_proto_version(Some((*max_ver).try_into().map_err(|v| { OpaqueError::from_display(format!("protocol version {v}")) @@ -117,7 +122,7 @@ where .context("build boring ssl acceptor: set max proto version")?; } - for ca_cert in self.data.config.client_cert_chain.iter().flatten() { + for ca_cert in tls_config.client_cert_chain.iter().flatten() { acceptor_builder .add_client_ca(ca_cert) .context("build boring ssl acceptor: set ca client cert")?; @@ -142,16 +147,14 @@ where None }; - if !self - .data - .config + if !tls_config .alpn_protocols .as_ref() .map(|v| !v.is_empty()) .unwrap_or_default() { let mut buf = vec![]; - for alpn in self.data.config.alpn_protocols.iter().flatten() { + for alpn in tls_config.alpn_protocols.iter().flatten() { alpn.encode_wire_format(&mut buf) .context("build boring ssl acceptor: encode alpn")?; } @@ -160,7 +163,7 @@ where .context("build boring ssl acceptor: set alpn")?; } - if let Some(keylog_filename) = &self.data.config.keylog_filename { + if let Some(keylog_filename) = &tls_config.keylog_filename { // open file in append mode and write keylog to it with callback let file = std::fs::OpenOptions::new() .append(true) diff --git a/rama-tls/src/rustls/client/http.rs b/rama-tls/src/rustls/client/http.rs index 43bf9a82..58edece9 100644 --- a/rama-tls/src/rustls/client/http.rs +++ b/rama-tls/src/rustls/client/http.rs @@ -228,13 +228,13 @@ where conn, addr, } = self.inner.connect(ctx, req).await.map_err(Into::into)?; - let transport_ctx = ctx .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx)) .map_err(|err| { OpaqueError::from_boxed(err.into()) .context("HttpsConnector(auto): compute transport context") - })?; + })? + .clone(); if !transport_ctx .app_protocol @@ -264,7 +264,8 @@ where "HttpsConnector(auto): attempt to secure inner connection", ); - let (stream, negotiated_params) = self.handshake(server_host, conn).await?; + let connector_data = ctx.get().cloned(); + let (stream, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; tracing::trace!( authority = %transport_ctx.authority, @@ -322,7 +323,8 @@ where let server_host = transport_ctx.authority.host().clone(); - let (conn, negotiated_params) = self.handshake(server_host, conn).await?; + let connector_data = ctx.get().cloned(); + let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; ctx.insert(negotiated_params); Ok(EstablishedClientConnection { @@ -372,7 +374,8 @@ where } }; - let (conn, negotiated_params) = self.handshake(server_host, conn).await?; + let connector_data = ctx.get().cloned(); + let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; ctx.insert(negotiated_params); tracing::trace!("HttpsConnector(tunnel): connection secured"); @@ -390,13 +393,14 @@ where impl HttpsConnector { async fn handshake( &self, + connector_data: Option, server_host: Host, stream: T, ) -> Result<(TlsStream, NegotiatedTlsParameters), BoxError> where T: Stream + Unpin, { - let (config, server_host) = match &self.connector_data { + let (config, server_host) = match connector_data.as_ref().or(self.connector_data.as_ref()) { Some(connector_data) => { let client_config = connector_data.client_config.clone(); let server_host = connector_data.server_name().cloned().unwrap_or(server_host); diff --git a/rama-tls/src/rustls/server/service.rs b/rama-tls/src/rustls/server/service.rs index 0d3b7d8d..814b923f 100644 --- a/rama-tls/src/rustls/server/service.rs +++ b/rama-tls/src/rustls/server/service.rs @@ -73,7 +73,9 @@ where type Error = BoxError; async fn serve(&self, mut ctx: Context, stream: IO) -> Result { - let acceptor = TlsAcceptor::from(self.data.server_config.clone()); + let tls_acceptor_data = ctx.get::().unwrap_or(&self.data); + + let acceptor = TlsAcceptor::from(tls_acceptor_data.server_config.clone()); let stream = acceptor.accept(stream).await?; let (_, conn_data_ref) = stream.get_ref(); @@ -106,6 +108,8 @@ where type Error = BoxError; async fn serve(&self, mut ctx: Context, stream: IO) -> Result { + let tls_acceptor_data = ctx.get::().unwrap_or(&self.data); + let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream); let start = acceptor.await?; @@ -116,7 +120,9 @@ where SecureTransport::default() }; - let stream = start.into_stream(self.data.server_config.clone()).await?; + let stream = start + .into_stream(tls_acceptor_data.server_config.clone()) + .await?; let (_, conn_data_ref) = stream.get_ref(); ctx.insert(NegotiatedTlsParameters { protocol_version: conn_data_ref @@ -160,15 +166,18 @@ where SecureTransport::default() }; - let service_data = self - .client_config_handler - .service_data_provider - .get_service_data(accepted_client_hello) - .await - .map_err(Into::into)? - .unwrap_or_else(|| self.data.clone()); + let tls_acceptor_data = match ctx.get::() { + Some(data) => data.clone(), + None => self + .client_config_handler + .service_data_provider + .get_service_data(accepted_client_hello) + .await + .map_err(Into::into)? + .unwrap_or_else(|| self.data.clone()), + }; - let stream = start.into_stream(service_data.server_config).await?; + let stream = start.into_stream(tls_acceptor_data.server_config).await?; let (_, conn_data_ref) = stream.get_ref(); ctx.insert(NegotiatedTlsParameters {