Skip to content

Commit

Permalink
allow acceptor/connector data to be injected via ext
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Sep 19, 2024
1 parent 4a342f2 commit 2414010
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 31 deletions.
15 changes: 10 additions & 5 deletions rama-tls/src/boring/client/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ where
.map_err(|err| {
OpaqueError::from_boxed(err.into())
.context("HttpsConnector(auto): compute transport context")
})?;
})?
.clone();

if !transport_ctx
.app_protocol
Expand All @@ -257,7 +258,8 @@ where

let host = transport_ctx.authority.host().clone();

let (stream, negotiated_params) = self.handshake(host, conn).await?;
let connector_data = ctx.get().cloned();
let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?;

tracing::trace!(
authority = %transport_ctx.authority,
Expand Down Expand Up @@ -313,7 +315,8 @@ where

let host = transport_ctx.authority.host().clone();

let (conn, negotiated_params) = self.handshake(host, conn).await?;
let connector_data = ctx.get().cloned();
let (conn, negotiated_params) = self.handshake(connector_data, host, conn).await?;
ctx.insert(negotiated_params);

Ok(EstablishedClientConnection {
Expand Down Expand Up @@ -363,7 +366,8 @@ where
}
};

let (stream, negotiated_params) = self.handshake(host, conn).await?;
let connector_data = ctx.get().cloned();
let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?;
ctx.insert(negotiated_params);

tracing::trace!("HttpsConnector(tunnel): connection secured");
Expand All @@ -381,13 +385,14 @@ where
impl<S, K> HttpsConnector<S, K> {
async fn handshake<T>(
&self,
connector_data: Option<TlsConnectorData>,
server_host: Host,
stream: T,
) -> Result<(SslStream<T>, NegotiatedTlsParameters), BoxError>
where
T: Stream + Unpin,
{
let (config, server_host) = match &self.connector_data {
let (config, server_host) = match connector_data.as_ref().or(self.connector_data.as_ref()) {
Some(connector_data) => {
let client_config = connector_data.connect_config_input.try_to_build_config()?;
let server_host = connector_data.server_name().cloned().unwrap_or(server_host);
Expand Down
23 changes: 13 additions & 10 deletions rama-tls/src/boring/server/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ where
type Error = BoxError;

async fn serve(&self, mut ctx: Context<T>, stream: IO) -> Result<Self::Response, Self::Error> {
// allow tls acceptor data to be injected,
// e.g. useful for TLS environments where some data (such as server auth, think ACME)
// is updated at runtime, be it infrequent
let tls_config = &ctx.get::<TlsAcceptorData>().unwrap_or(&self.data).config;

let mut acceptor_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())
.context("create boring ssl acceptor")?;

Expand All @@ -81,7 +86,7 @@ where
.set_default_verify_paths()
.context("build boring ssl acceptor: set default verify paths")?;

for (i, ca_cert) in self.data.config.cert_chain.iter().enumerate() {
for (i, ca_cert) in tls_config.cert_chain.iter().enumerate() {
if i == 0 {
acceptor_builder
.set_certificate(ca_cert.as_ref())
Expand All @@ -93,13 +98,13 @@ where
}
}
acceptor_builder
.set_private_key(self.data.config.private_key.as_ref())
.set_private_key(tls_config.private_key.as_ref())
.context("build boring ssl acceptor: set private key")?;
acceptor_builder
.check_private_key()
.context("build boring ssl acceptor: check private key")?;

if let Some(min_ver) = self.data.config.protocol_versions.iter().flatten().min() {
if let Some(min_ver) = tls_config.protocol_versions.iter().flatten().min() {
acceptor_builder
.set_min_proto_version(Some((*min_ver).try_into().map_err(|v| {
OpaqueError::from_display(format!("protocol version {v}"))
Expand All @@ -108,7 +113,7 @@ where
.context("build boring ssl acceptor: set min proto version")?;
}

if let Some(max_ver) = self.data.config.protocol_versions.iter().flatten().max() {
if let Some(max_ver) = tls_config.protocol_versions.iter().flatten().max() {
acceptor_builder
.set_max_proto_version(Some((*max_ver).try_into().map_err(|v| {
OpaqueError::from_display(format!("protocol version {v}"))
Expand All @@ -117,7 +122,7 @@ where
.context("build boring ssl acceptor: set max proto version")?;
}

for ca_cert in self.data.config.client_cert_chain.iter().flatten() {
for ca_cert in tls_config.client_cert_chain.iter().flatten() {
acceptor_builder
.add_client_ca(ca_cert)
.context("build boring ssl acceptor: set ca client cert")?;
Expand All @@ -142,16 +147,14 @@ where
None
};

if !self
.data
.config
if !tls_config
.alpn_protocols
.as_ref()
.map(|v| !v.is_empty())
.unwrap_or_default()
{
let mut buf = vec![];
for alpn in self.data.config.alpn_protocols.iter().flatten() {
for alpn in tls_config.alpn_protocols.iter().flatten() {
alpn.encode_wire_format(&mut buf)
.context("build boring ssl acceptor: encode alpn")?;
}
Expand All @@ -160,7 +163,7 @@ where
.context("build boring ssl acceptor: set alpn")?;
}

if let Some(keylog_filename) = &self.data.config.keylog_filename {
if let Some(keylog_filename) = &tls_config.keylog_filename {
// open file in append mode and write keylog to it with callback
let file = std::fs::OpenOptions::new()
.append(true)
Expand Down
16 changes: 10 additions & 6 deletions rama-tls/src/rustls/client/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,13 @@ where
conn,
addr,
} = self.inner.connect(ctx, req).await.map_err(Into::into)?;

let transport_ctx = ctx
.get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
.map_err(|err| {
OpaqueError::from_boxed(err.into())
.context("HttpsConnector(auto): compute transport context")
})?;
})?
.clone();

if !transport_ctx
.app_protocol
Expand Down Expand Up @@ -264,7 +264,8 @@ where
"HttpsConnector(auto): attempt to secure inner connection",
);

let (stream, negotiated_params) = self.handshake(server_host, conn).await?;
let connector_data = ctx.get().cloned();
let (stream, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;

tracing::trace!(
authority = %transport_ctx.authority,
Expand Down Expand Up @@ -322,7 +323,8 @@ where

let server_host = transport_ctx.authority.host().clone();

let (conn, negotiated_params) = self.handshake(server_host, conn).await?;
let connector_data = ctx.get().cloned();
let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
ctx.insert(negotiated_params);

Ok(EstablishedClientConnection {
Expand Down Expand Up @@ -372,7 +374,8 @@ where
}
};

let (conn, negotiated_params) = self.handshake(server_host, conn).await?;
let connector_data = ctx.get().cloned();
let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
ctx.insert(negotiated_params);

tracing::trace!("HttpsConnector(tunnel): connection secured");
Expand All @@ -390,13 +393,14 @@ where
impl<S, K> HttpsConnector<S, K> {
async fn handshake<T>(
&self,
connector_data: Option<TlsConnectorData>,
server_host: Host,
stream: T,
) -> Result<(TlsStream<T>, NegotiatedTlsParameters), BoxError>
where
T: Stream + Unpin,
{
let (config, server_host) = match &self.connector_data {
let (config, server_host) = match connector_data.as_ref().or(self.connector_data.as_ref()) {
Some(connector_data) => {
let client_config = connector_data.client_config.clone();
let server_host = connector_data.server_name().cloned().unwrap_or(server_host);
Expand Down
29 changes: 19 additions & 10 deletions rama-tls/src/rustls/server/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ where
type Error = BoxError;

async fn serve(&self, mut ctx: Context<T>, stream: IO) -> Result<Self::Response, Self::Error> {
let acceptor = TlsAcceptor::from(self.data.server_config.clone());
let tls_acceptor_data = ctx.get::<TlsAcceptorData>().unwrap_or(&self.data);

let acceptor = TlsAcceptor::from(tls_acceptor_data.server_config.clone());

let stream = acceptor.accept(stream).await?;
let (_, conn_data_ref) = stream.get_ref();
Expand Down Expand Up @@ -106,6 +108,8 @@ where
type Error = BoxError;

async fn serve(&self, mut ctx: Context<T>, stream: IO) -> Result<Self::Response, Self::Error> {
let tls_acceptor_data = ctx.get::<TlsAcceptorData>().unwrap_or(&self.data);

let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);

let start = acceptor.await?;
Expand All @@ -116,7 +120,9 @@ where
SecureTransport::default()
};

let stream = start.into_stream(self.data.server_config.clone()).await?;
let stream = start
.into_stream(tls_acceptor_data.server_config.clone())
.await?;
let (_, conn_data_ref) = stream.get_ref();
ctx.insert(NegotiatedTlsParameters {
protocol_version: conn_data_ref
Expand Down Expand Up @@ -160,15 +166,18 @@ where
SecureTransport::default()
};

let service_data = self
.client_config_handler
.service_data_provider
.get_service_data(accepted_client_hello)
.await
.map_err(Into::into)?
.unwrap_or_else(|| self.data.clone());
let tls_acceptor_data = match ctx.get::<TlsAcceptorData>() {
Some(data) => data.clone(),
None => self
.client_config_handler
.service_data_provider
.get_service_data(accepted_client_hello)
.await
.map_err(Into::into)?
.unwrap_or_else(|| self.data.clone()),
};

let stream = start.into_stream(service_data.server_config).await?;
let stream = start.into_stream(tls_acceptor_data.server_config).await?;

let (_, conn_data_ref) = stream.get_ref();
ctx.insert(NegotiatedTlsParameters {
Expand Down

0 comments on commit 2414010

Please sign in to comment.