diff --git a/Cargo.lock b/Cargo.lock index 3fec7718ce..317af86fff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5664,6 +5664,7 @@ dependencies = [ "rustls", "rustls-webpki", "serde", + "time 0.3.36", "tokio", "tokio-util", "tracing", @@ -5688,6 +5689,7 @@ dependencies = [ "rustls-pki-types", "rustls-webpki", "secrecy", + "time 0.3.36", "tokio", "tokio-util", "tracing", @@ -5745,6 +5747,7 @@ dependencies = [ "rustls-webpki", "secrecy", "socket2 0.5.7", + "time 0.3.36", "tls-listener", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index 721d1bced1..b5853f4897 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -167,6 +167,7 @@ socket2 = { version = "0.5.7", features = ["all"] } stop-token = "0.7.0" syn = "2.0" tide = "0.16.0" +time = "0.3.36" token-cell = { version = "1.5.0", default-features = false } tokio = { version = "1.40.0", default-features = false } # Default features are disabled due to some crates' requirements tokio-util = "0.7.12" diff --git a/DEFAULT_CONFIG.json5 b/DEFAULT_CONFIG.json5 index bb083a7e31..1285c703bd 100644 --- a/DEFAULT_CONFIG.json5 +++ b/DEFAULT_CONFIG.json5 @@ -466,6 +466,10 @@ // This could be dangerous because your CA can have signed a server cert for foo.com, that's later being used to host a server at baz.com. If you wan't your // ca to verify that the server at baz.com is actually baz.com, let this be true (default). verify_name_on_connect: true, + // Whether or not to close links when remote certificates expires. + // If set to true, links that require certificates (tls/quic) will automatically disconnect when the time of expiration of the remote certificate chain is reached + // note that mTLS (client authentication) is required for a listener to disconnect a client on expiration + close_link_on_expiration: false, }, }, /// Shared memory configuration. diff --git a/commons/zenoh-config/src/lib.rs b/commons/zenoh-config/src/lib.rs index b3df5e3cb5..0b9777f233 100644 --- a/commons/zenoh-config/src/lib.rs +++ b/commons/zenoh-config/src/lib.rs @@ -483,6 +483,7 @@ validated_struct::validator! { connect_private_key: Option, connect_certificate: Option, verify_name_on_connect: Option, + close_link_on_expiration: Option, // Skip serializing field because they contain secrets #[serde(skip_serializing)] root_ca_certificate_base64: Option, diff --git a/io/zenoh-link-commons/Cargo.toml b/io/zenoh-link-commons/Cargo.toml index 5f8f91aa61..a1bec81cbd 100644 --- a/io/zenoh-link-commons/Cargo.toml +++ b/io/zenoh-link-commons/Cargo.toml @@ -35,6 +35,7 @@ futures = { workspace = true } rustls = { workspace = true, optional = true } rustls-webpki = { workspace = true, optional = true } serde = { workspace = true, features = ["default"] } +time = { workspace = true } tokio = { workspace = true, features = [ "fs", "io-util", diff --git a/io/zenoh-link-commons/src/tls.rs b/io/zenoh-link-commons/src/tls.rs index 427880b812..4473607a0a 100644 --- a/io/zenoh-link-commons/src/tls.rs +++ b/io/zenoh-link-commons/src/tls.rs @@ -86,3 +86,124 @@ impl WebPkiVerifierAnyServerName { Self { roots } } } + +pub mod expiration { + use std::{ + net::SocketAddr, + sync::{atomic::AtomicBool, Weak}, + }; + + use async_trait::async_trait; + use time::OffsetDateTime; + use tokio::{sync::Mutex as AsyncMutex, task::JoinHandle}; + use tokio_util::sync::CancellationToken; + use zenoh_result::ZResult; + + #[async_trait] + pub trait LinkWithCertExpiration: Send + Sync { + async fn expire(&self) -> ZResult<()>; + } + + #[derive(Debug)] + pub struct LinkCertExpirationManager { + token: CancellationToken, + handle: AsyncMutex>>>, + /// Closing the link is a critical section that requires exclusive access to expiration_task + /// or the transport. `link_closing` is used to synchronize the access to this operation. + link_closing: AtomicBool, + } + + impl LinkCertExpirationManager { + pub fn new( + link: Weak, + src_addr: SocketAddr, + dst_addr: SocketAddr, + link_type: &'static str, + expiration_time: OffsetDateTime, + ) -> Self { + let token = CancellationToken::new(); + let handle = zenoh_runtime::ZRuntime::Acceptor.spawn(expiration_task( + link, + src_addr, + dst_addr, + link_type, + expiration_time, + token.clone(), + )); + Self { + token, + handle: AsyncMutex::new(Some(handle)), + link_closing: AtomicBool::new(false), + } + } + + /// Takes exclusive access to closing the link. + /// + /// Returns `true` if successful, `false` if another task is (or finished) closing the link. + pub fn set_closing(&self) -> bool { + !self + .link_closing + .swap(true, std::sync::atomic::Ordering::Relaxed) + } + + /// Sends cancelation signal to expiration_task + pub fn cancel_expiration_task(&self) { + self.token.cancel() + } + + /// Waits for expiration task to complete, returning its return value. + pub async fn wait_for_expiration_task(&self) -> ZResult<()> { + let mut lock = self.handle.lock().await; + let handle = lock.take().expect("handle should be set"); + handle.await? + } + } + + async fn expiration_task( + link: Weak, + src_addr: SocketAddr, + dst_addr: SocketAddr, + link_type: &'static str, + expiration_time: OffsetDateTime, + token: CancellationToken, + ) -> ZResult<()> { + tracing::trace!( + "Expiration task started for {} link {:?} => {:?}", + link_type.to_uppercase(), + src_addr, + dst_addr, + ); + tokio::select! { + _ = token.cancelled() => {}, + _ = sleep_until_date(expiration_time) => { + // expire the link + if let Some(link) = link.upgrade() { + tracing::warn!( + "Closing {} link {:?} => {:?} : remote certificate chain expired", + link_type.to_uppercase(), + src_addr, + dst_addr, + ); + return link.expire().await; + } + }, + } + Ok(()) + } + + async fn sleep_until_date(wakeup_time: OffsetDateTime) { + const MAX_SLEEP_DURATION: tokio::time::Duration = tokio::time::Duration::from_secs(600); + loop { + let now = OffsetDateTime::now_utc(); + if wakeup_time <= now { + break; + } + // next sleep duration is the minimum between MAX_SLEEP_DURATION and the duration till wakeup + // this mitigates the unsoundness of using `tokio::time::sleep` with long durations + let wakeup_duration = std::time::Duration::try_from(wakeup_time - now) + .expect("wakeup_time should be greater than now"); + let sleep_duration = tokio::time::Duration::min(MAX_SLEEP_DURATION, wakeup_duration); + tokio::time::sleep(sleep_duration).await; + } + } +} diff --git a/io/zenoh-links/zenoh-link-quic/Cargo.toml b/io/zenoh-links/zenoh-link-quic/Cargo.toml index b01f5e4261..6132abeaca 100644 --- a/io/zenoh-links/zenoh-link-quic/Cargo.toml +++ b/io/zenoh-links/zenoh-link-quic/Cargo.toml @@ -33,6 +33,7 @@ rustls-pemfile = { workspace = true } rustls-pki-types = { workspace = true } rustls-webpki = { workspace = true } secrecy = { workspace = true } +time = { workspace = true } tokio = { workspace = true, features = [ "fs", "io-util", diff --git a/io/zenoh-links/zenoh-link-quic/src/lib.rs b/io/zenoh-links/zenoh-link-quic/src/lib.rs index abaefd199c..0e94915284 100644 --- a/io/zenoh-links/zenoh-link-quic/src/lib.rs +++ b/io/zenoh-links/zenoh-link-quic/src/lib.rs @@ -112,4 +112,7 @@ pub mod config { pub const TLS_VERIFY_NAME_ON_CONNECT: &str = "verify_name_on_connect"; pub const TLS_VERIFY_NAME_ON_CONNECT_DEFAULT: bool = true; + + pub const TLS_CLOSE_LINK_ON_EXPIRATION: &str = "close_link_on_expiration"; + pub const TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT: bool = false; } diff --git a/io/zenoh-links/zenoh-link-quic/src/unicast.rs b/io/zenoh-links/zenoh-link-quic/src/unicast.rs index ea6ce646cc..3618a7a625 100644 --- a/io/zenoh-links/zenoh-link-quic/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-quic/src/unicast.rs @@ -21,13 +21,16 @@ use std::{ use async_trait::async_trait; use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig}; +use time::OffsetDateTime; use tokio::sync::Mutex as AsyncMutex; use tokio_util::sync::CancellationToken; -use x509_parser::prelude::*; +use x509_parser::prelude::{FromDer, X509Certificate}; use zenoh_core::zasynclock; use zenoh_link_commons::{ - get_ip_interface_names, LinkAuthId, LinkAuthType, LinkManagerUnicastTrait, LinkUnicast, - LinkUnicastTrait, ListenersUnicastIP, NewLinkChannelSender, + get_ip_interface_names, + tls::expiration::{LinkCertExpirationManager, LinkWithCertExpiration}, + LinkAuthId, LinkAuthType, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, + ListenersUnicastIP, NewLinkChannelSender, }; use zenoh_protocol::{ core::{EndPoint, Locator}, @@ -48,6 +51,7 @@ pub struct LinkUnicastQuic { send: AsyncMutex, recv: AsyncMutex, auth_identifier: LinkAuthId, + expiration_manager: Option, } impl LinkUnicastQuic { @@ -58,6 +62,7 @@ impl LinkUnicastQuic { send: quinn::SendStream, recv: quinn::RecvStream, auth_identifier: LinkAuthId, + expiration_manager: Option, ) -> LinkUnicastQuic { // Build the Quic object LinkUnicastQuic { @@ -68,12 +73,10 @@ impl LinkUnicastQuic { send: AsyncMutex::new(send), recv: AsyncMutex::new(recv), auth_identifier, + expiration_manager, } } -} -#[async_trait] -impl LinkUnicastTrait for LinkUnicastQuic { async fn close(&self) -> ZResult<()> { tracing::trace!("Closing QUIC link: {}", self); // Flush the QUIC stream @@ -84,6 +87,24 @@ impl LinkUnicastTrait for LinkUnicastQuic { self.connection.close(quinn::VarInt::from_u32(0), &[0]); Ok(()) } +} + +#[async_trait] +impl LinkUnicastTrait for LinkUnicastQuic { + async fn close(&self) -> ZResult<()> { + if let Some(expiration_manager) = &self.expiration_manager { + if !expiration_manager.set_closing() { + // expiration_task is closing link, return its returned ZResult to Transport + return expiration_manager.wait_for_expiration_task().await; + } + // cancel the expiration task and close link + expiration_manager.cancel_expiration_task(); + let res = self.close().await; + let _ = expiration_manager.wait_for_expiration_task().await; + return res; + } + self.close().await + } async fn write(&self, buffer: &[u8]) -> ZResult { let mut guard = zasynclock!(self.send); @@ -167,6 +188,21 @@ impl LinkUnicastTrait for LinkUnicastQuic { } } +#[async_trait] +impl LinkWithCertExpiration for LinkUnicastQuic { + async fn expire(&self) -> ZResult<()> { + let expiration_manager = self + .expiration_manager + .as_ref() + .expect("expiration_manager should be set"); + if expiration_manager.set_closing() { + return self.close().await; + } + // Transport is already closing the link + Ok(()) + } +} + impl Drop for LinkUnicastQuic { fn drop(&mut self) { self.connection.close(quinn::VarInt::from_u32(0), &[0]); @@ -219,17 +255,17 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { .ok_or("Endpoints must be of the form quic/
:")?; let epconf = endpoint.config(); - let addr = get_quic_addr(&epaddr).await?; + let dst_addr = get_quic_addr(&epaddr).await?; // Initialize the QUIC connection let mut client_crypto = TlsClientConfig::new(&epconf) .await - .map_err(|e| zerror!("Cannot create a new QUIC client on {addr}: {e}"))?; + .map_err(|e| zerror!("Cannot create a new QUIC client on {dst_addr}: {e}"))?; client_crypto.client_config.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); - let ip_addr: IpAddr = if addr.is_ipv4() { + let ip_addr: IpAddr = if dst_addr.is_ipv4() { Ipv4Addr::UNSPECIFIED.into() } else { Ipv6Addr::UNSPECIFIED.into() @@ -248,7 +284,7 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { .map_err(|e| zerror!("Can not create a new QUIC link bound to {}: {}", host, e))?; let quic_conn = quic_endpoint - .connect(addr, host) + .connect(dst_addr, host) .map_err(|e| zerror!("Can not create a new QUIC link bound to {}: {}", host, e))? .await .map_err(|e| zerror!("Can not create a new QUIC link bound to {}: {}", host, e))?; @@ -259,15 +295,31 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { .map_err(|e| zerror!("Can not create a new QUIC link bound to {}: {}", host, e))?; let auth_id = get_cert_common_name(&quic_conn)?; - - let link = Arc::new(LinkUnicastQuic::new( - quic_conn, - src_addr, - endpoint.into(), - send, - recv, - auth_id.into(), - )); + let certchain_expiration_time = + get_cert_chain_expiration(&quic_conn)?.expect("server should have certificate chain"); + + let link = Arc::::new_cyclic(|weak_link| { + let mut expiration_manager = None; + if client_crypto.tls_close_link_on_expiration { + // setup expiration manager + expiration_manager = Some(LinkCertExpirationManager::new( + weak_link.clone(), + src_addr, + dst_addr, + QUIC_LOCATOR_PREFIX, + certchain_expiration_time, + )) + } + LinkUnicastQuic::new( + quic_conn, + src_addr, + endpoint.into(), + send, + recv, + auth_id.into(), + expiration_manager, + ) + }); Ok(LinkUnicast(link)) } @@ -337,7 +389,15 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { let token = token.clone(); let manager = self.manager.clone(); - async move { accept_task(quic_endpoint, token, manager).await } + async move { + accept_task( + quic_endpoint, + token, + manager, + server_crypto.tls_close_link_on_expiration, + ) + .await + } }; // Initialize the QuicAcceptor @@ -369,6 +429,7 @@ async fn accept_task( quic_endpoint: quinn::Endpoint, token: CancellationToken, manager: NewLinkChannelSender, + tls_close_link_on_expiration: bool, ) -> ZResult<()> { async fn accept(acceptor: quinn::Accept<'_>) -> ZResult { let qc = acceptor @@ -416,19 +477,47 @@ async fn accept_task( } }; let dst_addr = quic_conn.remote_address(); + let dst_locator = Locator::new(QUIC_LOCATOR_PREFIX, dst_addr.to_string(), "")?; // Get Quic auth identifier let auth_id = get_cert_common_name(&quic_conn)?; + // Get certificate chain expiration + let mut maybe_expiration_time = None; + if tls_close_link_on_expiration { + match get_cert_chain_expiration(&quic_conn)? { + exp @ Some(_) => maybe_expiration_time = exp, + None => tracing::warn!( + "Cannot monitor expiration for QUIC link {:?} => {:?} : client does not have certificates", + src_addr, + dst_addr, + ), + } + } + tracing::debug!("Accepted QUIC connection on {:?}: {:?}", src_addr, dst_addr); // Create the new link object - let link = Arc::new(LinkUnicastQuic::new( - quic_conn, - src_addr, - Locator::new(QUIC_LOCATOR_PREFIX, dst_addr.to_string(), "")?, - send, - recv, - auth_id.into() - )); + let link = Arc::::new_cyclic(|weak_link| { + let mut expiration_manager = None; + if let Some(certchain_expiration_time) = maybe_expiration_time { + // setup expiration manager + expiration_manager = Some(LinkCertExpirationManager::new( + weak_link.clone(), + src_addr, + dst_addr, + QUIC_LOCATOR_PREFIX, + certchain_expiration_time, + )); + } + LinkUnicastQuic::new( + quic_conn, + src_addr, + dst_locator, + send, + recv, + auth_id.into(), + expiration_manager, + ) + }); // Communicate the new link to the initial transport manager if let Err(e) = manager.send_async(LinkUnicast(link)).await { @@ -475,6 +564,24 @@ fn get_cert_common_name(conn: &quinn::Connection) -> ZResult { Ok(auth_id) } +/// Returns the minimum value of the `not_after` field in the remote certificate chain. +/// Returns `None` if the remote certificate chain is empty +fn get_cert_chain_expiration(conn: &quinn::Connection) -> ZResult> { + let mut link_expiration: Option = None; + if let Some(pi) = conn.peer_identity() { + if let Ok(remote_certs) = pi.downcast::>() { + for cert in *remote_certs { + let (_, cert) = X509Certificate::from_der(cert.as_ref())?; + let cert_expiration = cert.validity().not_after.to_datetime(); + link_expiration = link_expiration + .map(|current_min| current_min.min(cert_expiration)) + .or(Some(cert_expiration)); + } + } + } + Ok(link_expiration) +} + #[derive(Debug, Clone)] struct QuicAuthId { auth_value: Option, diff --git a/io/zenoh-links/zenoh-link-quic/src/utils.rs b/io/zenoh-links/zenoh-link-quic/src/utils.rs index afd361f634..3a7c0b26a8 100644 --- a/io/zenoh-links/zenoh-link-quic/src/utils.rs +++ b/io/zenoh-links/zenoh-link-quic/src/utils.rs @@ -141,12 +141,21 @@ impl ConfigurationInspector for TlsConfigurator { false => ps.push((TLS_VERIFY_NAME_ON_CONNECT, "false")), }; + match c + .close_link_on_expiration() + .unwrap_or(TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT) + { + true => ps.push((TLS_CLOSE_LINK_ON_EXPIRATION, "true")), + false => ps.push((TLS_CLOSE_LINK_ON_EXPIRATION, "false")), + } + Ok(parameters::from_iter(ps.drain(..))) } } pub(crate) struct TlsServerConfig { pub(crate) server_config: ServerConfig, + pub(crate) tls_close_link_on_expiration: bool, } impl TlsServerConfig { @@ -157,6 +166,12 @@ impl TlsServerConfig { .map_err(|_| zerror!("Unknown enable mTLS argument: {}", s))?, None => false, }; + let tls_close_link_on_expiration: bool = match config.get(TLS_CLOSE_LINK_ON_EXPIRATION) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown close on expiration argument: {}", s))?, + None => TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT, + }; let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; @@ -215,7 +230,10 @@ impl TlsServerConfig { .with_single_cert(certs, keys.remove(0)) .map_err(|e| zerror!(e))? }; - Ok(TlsServerConfig { server_config: sc }) + Ok(TlsServerConfig { + server_config: sc, + tls_close_link_on_expiration, + }) } async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { @@ -241,6 +259,7 @@ impl TlsServerConfig { pub(crate) struct TlsClientConfig { pub(crate) client_config: ClientConfig, + pub(crate) tls_close_link_on_expiration: bool, } impl TlsClientConfig { @@ -265,6 +284,13 @@ impl TlsClientConfig { None => false, }; + let tls_close_link_on_expiration: bool = match config.get(TLS_CLOSE_LINK_ON_EXPIRATION) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown close on expiration argument: {}", s))?, + None => TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT, + }; + // Allows mixed user-generated CA and webPKI CA tracing::debug!("Loading default Web PKI certificates."); let mut root_cert_store = RootCertStore { @@ -351,7 +377,10 @@ impl TlsClientConfig { .with_no_client_auth() } }; - Ok(TlsClientConfig { client_config: cc }) + Ok(TlsClientConfig { + client_config: cc, + tls_close_link_on_expiration, + }) } async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { diff --git a/io/zenoh-links/zenoh-link-tls/Cargo.toml b/io/zenoh-links/zenoh-link-tls/Cargo.toml index 5fc8d3ad69..281d890beb 100644 --- a/io/zenoh-links/zenoh-link-tls/Cargo.toml +++ b/io/zenoh-links/zenoh-link-tls/Cargo.toml @@ -36,6 +36,7 @@ socket2 = { workspace = true } tokio = { workspace = true, features = ["fs", "io-util", "net", "sync"] } tokio-rustls = { workspace = true } tokio-util = { workspace = true, features = ["rt"] } +time = { workspace = true } tracing = { workspace = true } x509-parser = { workspace = true } webpki-roots = { workspace = true } diff --git a/io/zenoh-links/zenoh-link-tls/src/lib.rs b/io/zenoh-links/zenoh-link-tls/src/lib.rs index a547c5d77f..4710cfd332 100644 --- a/io/zenoh-links/zenoh-link-tls/src/lib.rs +++ b/io/zenoh-links/zenoh-link-tls/src/lib.rs @@ -109,6 +109,9 @@ pub mod config { pub const TLS_VERIFY_NAME_ON_CONNECT: &str = "verify_name_on_connect"; pub const TLS_VERIFY_NAME_ON_CONNECT_DEFAULT: bool = true; + pub const TLS_CLOSE_LINK_ON_EXPIRATION: &str = "close_link_on_expiration"; + pub const TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT: bool = false; + /// The time duration in milliseconds to wait for the TLS handshake to complete. pub const TLS_HANDSHAKE_TIMEOUT_MS: &str = "tls_handshake_timeout_ms"; pub const TLS_HANDSHAKE_TIMEOUT_MS_DEFAULT: u64 = 10_000; diff --git a/io/zenoh-links/zenoh-link-tls/src/unicast.rs b/io/zenoh-links/zenoh-link-tls/src/unicast.rs index 3654e7800f..046288800e 100644 --- a/io/zenoh-links/zenoh-link-tls/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-tls/src/unicast.rs @@ -14,6 +14,7 @@ use std::{cell::UnsafeCell, convert::TryInto, fmt, net::SocketAddr, sync::Arc, time::Duration}; use async_trait::async_trait; +use time::OffsetDateTime; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, @@ -21,11 +22,13 @@ use tokio::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream}; use tokio_util::sync::CancellationToken; -use x509_parser::prelude::*; +use x509_parser::prelude::{FromDer, X509Certificate}; use zenoh_core::zasynclock; use zenoh_link_commons::{ - get_ip_interface_names, LinkAuthId, LinkAuthType, LinkManagerUnicastTrait, LinkUnicast, - LinkUnicastTrait, ListenersUnicastIP, NewLinkChannelSender, + get_ip_interface_names, + tls::expiration::{LinkCertExpirationManager, LinkWithCertExpiration}, + LinkAuthId, LinkAuthType, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, + ListenersUnicastIP, NewLinkChannelSender, }; use zenoh_protocol::{ core::{EndPoint, Locator}, @@ -62,6 +65,7 @@ pub struct LinkUnicastTls { read_mtx: AsyncMutex<()>, auth_identifier: LinkAuthId, mtu: BatchSize, + expiration_manager: Option, } unsafe impl Send for LinkUnicastTls {} @@ -73,6 +77,7 @@ impl LinkUnicastTls { src_addr: SocketAddr, dst_addr: SocketAddr, auth_identifier: LinkAuthId, + expiration_manager: Option, ) -> LinkUnicastTls { let (tcp_stream, _) = socket.get_ref(); // Set the TLS nodelay option @@ -131,6 +136,7 @@ impl LinkUnicastTls { read_mtx: AsyncMutex::new(()), auth_identifier, mtu, + expiration_manager, } } @@ -141,10 +147,7 @@ impl LinkUnicastTls { fn get_mut_socket(&self) -> &mut TlsStream { unsafe { &mut *self.inner.get() } } -} -#[async_trait] -impl LinkUnicastTrait for LinkUnicastTls { async fn close(&self) -> ZResult<()> { tracing::trace!("Closing TLS link: {}", self); // Flush the TLS stream @@ -158,6 +161,24 @@ impl LinkUnicastTrait for LinkUnicastTls { tracing::trace!("TLS link shutdown {}: {:?}", self, res); res.map_err(|e| zerror!(e).into()) } +} + +#[async_trait] +impl LinkUnicastTrait for LinkUnicastTls { + async fn close(&self) -> ZResult<()> { + if let Some(expiration_manager) = &self.expiration_manager { + if !expiration_manager.set_closing() { + // expiration_task is closing link, return its returned ZResult to Transport + return expiration_manager.wait_for_expiration_task().await; + } + // cancel the expiration task and close link + expiration_manager.cancel_expiration_task(); + let res = self.close().await; + let _ = expiration_manager.wait_for_expiration_task().await; + return res; + } + self.close().await + } async fn write(&self, buffer: &[u8]) -> ZResult { let _guard = zasynclock!(self.write_mtx); @@ -232,6 +253,21 @@ impl LinkUnicastTrait for LinkUnicastTls { } } +#[async_trait] +impl LinkWithCertExpiration for LinkUnicastTls { + async fn expire(&self) -> ZResult<()> { + let expiration_manager = self + .expiration_manager + .as_ref() + .expect("expiration_manager should be set"); + if expiration_manager.set_closing() { + return self.close().await; + } + // Transport is already closing the link + Ok(()) + } +} + impl Drop for LinkUnicastTls { fn drop(&mut self) { // Close the underlying TCP stream @@ -326,15 +362,31 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastTls { let (_, tls_conn) = tls_stream.get_ref(); let auth_identifier = get_server_cert_common_name(tls_conn)?; + let certchain_expiration_time = get_cert_chain_expiration(&tls_conn.peer_certificates())? + .expect("server should have certificate chain"); let tls_stream = TlsStream::Client(tls_stream); - let link = Arc::new(LinkUnicastTls::new( - tls_stream, - src_addr, - dst_addr, - auth_identifier.into(), - )); + let link = Arc::::new_cyclic(|weak_link| { + let mut expiration_manager = None; + if client_config.tls_close_link_on_expiration { + // setup expiration manager + expiration_manager = Some(LinkCertExpirationManager::new( + weak_link.clone(), + src_addr, + dst_addr, + TLS_LOCATOR_PREFIX, + certchain_expiration_time, + )) + } + LinkUnicastTls::new( + tls_stream, + dst_addr, + src_addr, + auth_identifier.into(), + expiration_manager, + ) + }); Ok(LinkUnicast(link)) } @@ -376,6 +428,7 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastTls { token, manager, tls_server_config.tls_handshake_timeout, + tls_server_config.tls_close_link_on_expiration, ) .await } @@ -416,6 +469,7 @@ async fn accept_task( token: CancellationToken, manager: NewLinkChannelSender, tls_handshake_timeout: Duration, + tls_close_link_on_expiration: bool, ) -> ZResult<()> { let src_addr = socket.local_addr().map_err(|e| { let e = zerror!("Can not accept TLS connections: {}", e); @@ -445,14 +499,41 @@ async fn accept_task( }; let auth_identifier = get_client_cert_common_name(tls_conn)?; + // Get certificate chain expiration + let mut maybe_expiration_time = None; + if tls_close_link_on_expiration { + match get_cert_chain_expiration(&tls_conn.peer_certificates())? { + exp @ Some(_) => maybe_expiration_time = exp, + None => tracing::warn!( + "Cannot monitor expiration for TLS link {:?} => {:?} : client does not have certificates", + src_addr, + dst_addr, + ), + } + } + tracing::debug!("Accepted TLS connection on {:?}: {:?}", src_addr, dst_addr); // Create the new link object - let link = Arc::new(LinkUnicastTls::new( - tokio_rustls::TlsStream::Server(tls_stream), - src_addr, - dst_addr, - auth_identifier.into(), - )); + let link = Arc::::new_cyclic(|weak_link| { + let mut expiration_manager = None; + if let Some(certchain_expiration_time) = maybe_expiration_time { + // setup expiration manager + expiration_manager = Some(LinkCertExpirationManager::new( + weak_link.clone(), + src_addr, + dst_addr, + TLS_LOCATOR_PREFIX, + certchain_expiration_time, + )); + } + LinkUnicastTls::new( + tokio_rustls::TlsStream::Server(tls_stream), + dst_addr, + src_addr, + auth_identifier.into(), + expiration_manager, + ) + }); // Communicate the new link to the initial transport manager if let Err(e) = manager.send_async(LinkUnicast(link)).await { @@ -517,6 +598,24 @@ fn get_server_cert_common_name(tls_conn: &rustls::ClientConnection) -> ZResult, +) -> ZResult> { + let mut link_expiration: Option = None; + if let Some(remote_certs) = cert_chain { + for cert in *remote_certs { + let (_, cert) = X509Certificate::from_der(cert.as_ref())?; + let cert_expiration = cert.validity().not_after.to_datetime(); + link_expiration = link_expiration + .map(|current_min| current_min.min(cert_expiration)) + .or(Some(cert_expiration)); + } + } + Ok(link_expiration) +} + struct TlsAuthId { auth_value: Option, } diff --git a/io/zenoh-links/zenoh-link-tls/src/utils.rs b/io/zenoh-links/zenoh-link-tls/src/utils.rs index b6e2c69578..8f8f766024 100644 --- a/io/zenoh-links/zenoh-link-tls/src/utils.rs +++ b/io/zenoh-links/zenoh-link-tls/src/utils.rs @@ -144,6 +144,14 @@ impl ConfigurationInspector for TlsConfigurator { false => ps.push((TLS_VERIFY_NAME_ON_CONNECT, "false")), }; + match c + .close_link_on_expiration() + .unwrap_or(TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT) + { + true => ps.push((TLS_CLOSE_LINK_ON_EXPIRATION, "true")), + false => ps.push((TLS_CLOSE_LINK_ON_EXPIRATION, "false")), + } + Ok(parameters::from_iter(ps.drain(..))) } } @@ -151,6 +159,7 @@ impl ConfigurationInspector for TlsConfigurator { pub(crate) struct TlsServerConfig { pub(crate) server_config: ServerConfig, pub(crate) tls_handshake_timeout: Duration, + pub(crate) tls_close_link_on_expiration: bool, } impl TlsServerConfig { @@ -161,6 +170,12 @@ impl TlsServerConfig { .map_err(|_| zerror!("Unknown enable mTLS argument: {}", s))?, None => false, }; + let tls_close_link_on_expiration: bool = match config.get(TLS_CLOSE_LINK_ON_EXPIRATION) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown close on expiration argument: {}", s))?, + None => TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT, + }; let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; @@ -231,6 +246,7 @@ impl TlsServerConfig { Ok(TlsServerConfig { server_config: sc, tls_handshake_timeout, + tls_close_link_on_expiration, }) } @@ -257,6 +273,7 @@ impl TlsServerConfig { pub(crate) struct TlsClientConfig { pub(crate) client_config: ClientConfig, + pub(crate) tls_close_link_on_expiration: bool, } impl TlsClientConfig { @@ -281,6 +298,13 @@ impl TlsClientConfig { None => false, }; + let tls_close_link_on_expiration: bool = match config.get(TLS_CLOSE_LINK_ON_EXPIRATION) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown close on expiration argument: {}", s))?, + None => TLS_CLOSE_LINK_ON_EXPIRATION_DEFAULT, + }; + // Allows mixed user-generated CA and webPKI CA tracing::debug!("Loading default Web PKI certificates."); let mut root_cert_store = RootCertStore { @@ -367,7 +391,10 @@ impl TlsClientConfig { .with_no_client_auth() } }; - Ok(TlsClientConfig { client_config: cc }) + Ok(TlsClientConfig { + client_config: cc, + tls_close_link_on_expiration, + }) } async fn load_tls_private_key(config: &Config<'_>) -> ZResult> {