From 89662fdd258850ab074482fa1d3a104442b0a0dd Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Thu, 27 Jun 2024 14:30:59 -0600 Subject: [PATCH] Set keep-alive on TCP sockets (#330) Migrations and other long-running operations may leave a TCP connection active but without any traffic. To prevent the operating systems on either end from closing the connections (or any other system that monitors for activity on TCP streams to determine when they should be closed), we can set the keep-alive flag on the client. --- edgedb-tokio/Cargo.toml | 1 + edgedb-tokio/src/builder.rs | 57 ++++++++++++++++++++++++++---- edgedb-tokio/src/credentials.rs | 3 +- edgedb-tokio/src/lib.rs | 2 +- edgedb-tokio/src/raw/connection.rs | 18 ++++++++++ flake.nix | 2 +- 6 files changed, 73 insertions(+), 10 deletions(-) diff --git a/edgedb-tokio/Cargo.toml b/edgedb-tokio/Cargo.toml index e97426de..260edced 100644 --- a/edgedb-tokio/Cargo.toml +++ b/edgedb-tokio/Cargo.toml @@ -50,6 +50,7 @@ once_cell = "1.9.0" tokio-stream = { version = "0.1.11", optional = true } base64 = "0.22.1" crc16 = "0.4.0" +socket2 = "0.5" [target.'cfg(target_family="unix")'.dev-dependencies] command-fds = "0.3.0" diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index 17ad0e59..49afe629 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -26,6 +26,7 @@ use crate::tls; pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); pub const DEFAULT_WAIT: Duration = Duration::from_secs(30); +pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60); pub const DEFAULT_POOL_SIZE: usize = 10; pub const DEFAULT_HOST: &str = "localhost"; pub const DEFAULT_PORT: u16 = 5656; @@ -44,13 +45,14 @@ static PORT_WARN: std::sync::Once = std::sync::Once::new(); type Verifier = Arc; /// Client security mode. -#[derive(Debug, Clone, Copy)] +#[derive(Default, Debug, Clone, Copy)] pub enum ClientSecurity { /// Disable security checks InsecureDevMode, /// Always verify domain an certificate Strict, /// Verify domain only if no specific certificate is configured + #[default] Default, } @@ -61,6 +63,28 @@ pub enum CloudCerts { Local, } +/// TCP keepalive configuration. +#[derive(Default, Debug, Clone, Copy)] +pub enum TcpKeepalive { + /// Disable TCP keepalive probes. + Disabled, + /// Explicit duration between TCP keepalive probes. + Explicit(Duration), + /// Default: 60 seconds. + #[default] + Default, +} + +impl TcpKeepalive { + fn as_keepalive(&self) -> Option { + match self { + TcpKeepalive::Disabled => None, + TcpKeepalive::Default => Some(DEFAULT_TCP_KEEPALIVE), + TcpKeepalive::Explicit(duration) => Some(*duration), + } + } +} + /// A builder used to create connections. #[derive(Debug, Clone, Default)] pub struct Builder { @@ -83,6 +107,7 @@ pub struct Builder { wait_until_available: Option, admin: bool, connect_timeout: Option, + tcp_keepalive: Option, secret_key: Option, cloud_profile: Option, @@ -121,6 +146,9 @@ pub(crate) struct ConfigInner { #[allow(dead_code)] // used only on unstable feature pub creds_file_outdated: bool, + // Whether to set TCP keepalive or not + pub tcp_keepalive: Option, + // Pool configuration pub max_concurrency: Option, @@ -785,6 +813,21 @@ impl Builder { self } + /// Sets the TCP keepalive interval and time for the database connection to + /// ensure that the remote end of the connection is still alive, and to + /// inform any network intermediaries that this connection is not idle. By + /// default, a keepalive probe will be sent once every 60 seconds once the + /// connection has been idle for 60 seconds. + /// + /// Note: If the connection is not made over a TCP socket, this value will + /// be unused. If the current platform does not support explicit TCP + /// keep-alive intervals on the socket, keepalives will be enabled and the + /// operating-system default for the intervals will be used. + pub fn tcp_keepalive(&mut self, tcp_keepalive: TcpKeepalive) -> &mut Self { + self.tcp_keepalive = Some(tcp_keepalive); + self + } + /// Set the maximum number of underlying database connections. pub fn max_concurrency(&mut self, value: usize) -> &mut Self { self.max_concurrency = Some(value); @@ -869,17 +912,17 @@ impl Builder { .pem_certificates .clone() .or_else(|| creds.and_then(|c| c.tls_ca.clone())), - + tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(), // Pool configuration max_concurrency: self.max_concurrency, // Temporary placeholders verifier: Arc::new(tls::NullVerifier), - client_security: self.client_security.unwrap_or(ClientSecurity::Default), + client_security: self.client_security.unwrap_or_default(), tls_security: self .tls_security .or_else(|| creds.map(|c| c.tls_security)) - .unwrap_or(TlsSecurity::Default), + .unwrap_or_default(), }; cfg.verifier = cfg.make_verifier(cfg.compute_tls_security()?); @@ -1504,9 +1547,9 @@ impl Builder { extra_dsn_query_args: HashMap::new(), creds_file_outdated: false, pem_certificates: self.pem_certificates.clone(), - client_security: self.client_security.unwrap_or(ClientSecurity::Default), - tls_security: self.tls_security.unwrap_or(TlsSecurity::Default), - + client_security: self.client_security.unwrap_or_default(), + tls_security: self.tls_security.unwrap_or_default(), + tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(), // Pool configuration max_concurrency: self.max_concurrency, diff --git a/edgedb-tokio/src/credentials.rs b/edgedb-tokio/src/credentials.rs index 7d14c955..8ecb8f26 100644 --- a/edgedb-tokio/src/credentials.rs +++ b/edgedb-tokio/src/credentials.rs @@ -7,7 +7,7 @@ use serde::{ser, Deserialize, Serialize}; use crate::errors::{Error, ErrorKind}; /// TLS Client Security Mode -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum TlsSecurity { /// Allow any certificate for TLS connection @@ -23,6 +23,7 @@ pub enum TlsSecurity { Strict, /// If there is a specific certificate in credentials, do not check /// the host name, otherwise use `Strict` mode + #[default] Default, } diff --git a/edgedb-tokio/src/lib.rs b/edgedb-tokio/src/lib.rs index 9b796958..6cb60db8 100644 --- a/edgedb-tokio/src/lib.rs +++ b/edgedb-tokio/src/lib.rs @@ -140,7 +140,7 @@ pub mod tutorial; pub use edgedb_derive::{ConfigDelta, GlobalsDelta, Queryable}; -pub use builder::{Builder, ClientSecurity, Config, InstanceName}; +pub use builder::{Builder, ClientSecurity, Config, InstanceName, TcpKeepalive}; pub use client::Client; pub use credentials::TlsSecurity; pub use errors::Error; diff --git a/edgedb-tokio/src/raw/connection.rs b/edgedb-tokio/src/raw/connection.rs index f92278f2..e0d071e8 100644 --- a/edgedb-tokio/src/raw/connection.rs +++ b/edgedb-tokio/src/raw/connection.rs @@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut}; use rand::{thread_rng, Rng}; use rustls::pki_types::DnsName; use scram::ScramClient; +use socket2::TcpKeepalive; use tls_api::TlsConnectorBuilder; use tls_api::{TlsConnector, TlsConnectorBox, TlsStream, TlsStreamDyn}; use tls_api_not_tls::TlsConnector as PlainConnector; @@ -335,6 +336,23 @@ async fn connect3(cfg: &Config, tls: &TlsConnectorBox) -> Result Cow::from(server_name), None => { diff --git a/flake.nix b/flake.nix index a8adc065..d0a2a734 100644 --- a/flake.nix +++ b/flake.nix @@ -65,7 +65,7 @@ buildInputs = [ (fenix_pkgs.toolchainOf { channel = "beta"; - sha256 = "sha256-H1BZtppFoMkxdDQ6ZVbTSg9PoKzkvsEbSSPIoB55t1w="; + sha256 = "sha256-WtTNSmxfoiHJEwCUnuDNfRNBZjNrzdBV02Hikw+YE+s="; }).defaultToolchain ] ++ common; };