diff --git a/edgedb-tokio/Cargo.toml b/edgedb-tokio/Cargo.toml index e97426de..260edced 100644 --- a/edgedb-tokio/Cargo.toml +++ b/edgedb-tokio/Cargo.toml @@ -50,6 +50,7 @@ once_cell = "1.9.0" tokio-stream = { version = "0.1.11", optional = true } base64 = "0.22.1" crc16 = "0.4.0" +socket2 = "0.5" [target.'cfg(target_family="unix")'.dev-dependencies] command-fds = "0.3.0" diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index 17ad0e59..49afe629 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -26,6 +26,7 @@ use crate::tls; pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); pub const DEFAULT_WAIT: Duration = Duration::from_secs(30); +pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60); pub const DEFAULT_POOL_SIZE: usize = 10; pub const DEFAULT_HOST: &str = "localhost"; pub const DEFAULT_PORT: u16 = 5656; @@ -44,13 +45,14 @@ static PORT_WARN: std::sync::Once = std::sync::Once::new(); type Verifier = Arc; /// Client security mode. -#[derive(Debug, Clone, Copy)] +#[derive(Default, Debug, Clone, Copy)] pub enum ClientSecurity { /// Disable security checks InsecureDevMode, /// Always verify domain an certificate Strict, /// Verify domain only if no specific certificate is configured + #[default] Default, } @@ -61,6 +63,28 @@ pub enum CloudCerts { Local, } +/// TCP keepalive configuration. +#[derive(Default, Debug, Clone, Copy)] +pub enum TcpKeepalive { + /// Disable TCP keepalive probes. + Disabled, + /// Explicit duration between TCP keepalive probes. + Explicit(Duration), + /// Default: 60 seconds. + #[default] + Default, +} + +impl TcpKeepalive { + fn as_keepalive(&self) -> Option { + match self { + TcpKeepalive::Disabled => None, + TcpKeepalive::Default => Some(DEFAULT_TCP_KEEPALIVE), + TcpKeepalive::Explicit(duration) => Some(*duration), + } + } +} + /// A builder used to create connections. #[derive(Debug, Clone, Default)] pub struct Builder { @@ -83,6 +107,7 @@ pub struct Builder { wait_until_available: Option, admin: bool, connect_timeout: Option, + tcp_keepalive: Option, secret_key: Option, cloud_profile: Option, @@ -121,6 +146,9 @@ pub(crate) struct ConfigInner { #[allow(dead_code)] // used only on unstable feature pub creds_file_outdated: bool, + // Whether to set TCP keepalive or not + pub tcp_keepalive: Option, + // Pool configuration pub max_concurrency: Option, @@ -785,6 +813,21 @@ impl Builder { self } + /// Sets the TCP keepalive interval and time for the database connection to + /// ensure that the remote end of the connection is still alive, and to + /// inform any network intermediaries that this connection is not idle. By + /// default, a keepalive probe will be sent once every 60 seconds once the + /// connection has been idle for 60 seconds. + /// + /// Note: If the connection is not made over a TCP socket, this value will + /// be unused. If the current platform does not support explicit TCP + /// keep-alive intervals on the socket, keepalives will be enabled and the + /// operating-system default for the intervals will be used. + pub fn tcp_keepalive(&mut self, tcp_keepalive: TcpKeepalive) -> &mut Self { + self.tcp_keepalive = Some(tcp_keepalive); + self + } + /// Set the maximum number of underlying database connections. pub fn max_concurrency(&mut self, value: usize) -> &mut Self { self.max_concurrency = Some(value); @@ -869,17 +912,17 @@ impl Builder { .pem_certificates .clone() .or_else(|| creds.and_then(|c| c.tls_ca.clone())), - + tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(), // Pool configuration max_concurrency: self.max_concurrency, // Temporary placeholders verifier: Arc::new(tls::NullVerifier), - client_security: self.client_security.unwrap_or(ClientSecurity::Default), + client_security: self.client_security.unwrap_or_default(), tls_security: self .tls_security .or_else(|| creds.map(|c| c.tls_security)) - .unwrap_or(TlsSecurity::Default), + .unwrap_or_default(), }; cfg.verifier = cfg.make_verifier(cfg.compute_tls_security()?); @@ -1504,9 +1547,9 @@ impl Builder { extra_dsn_query_args: HashMap::new(), creds_file_outdated: false, pem_certificates: self.pem_certificates.clone(), - client_security: self.client_security.unwrap_or(ClientSecurity::Default), - tls_security: self.tls_security.unwrap_or(TlsSecurity::Default), - + client_security: self.client_security.unwrap_or_default(), + tls_security: self.tls_security.unwrap_or_default(), + tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(), // Pool configuration max_concurrency: self.max_concurrency, diff --git a/edgedb-tokio/src/credentials.rs b/edgedb-tokio/src/credentials.rs index 7d14c955..8ecb8f26 100644 --- a/edgedb-tokio/src/credentials.rs +++ b/edgedb-tokio/src/credentials.rs @@ -7,7 +7,7 @@ use serde::{ser, Deserialize, Serialize}; use crate::errors::{Error, ErrorKind}; /// TLS Client Security Mode -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum TlsSecurity { /// Allow any certificate for TLS connection @@ -23,6 +23,7 @@ pub enum TlsSecurity { Strict, /// If there is a specific certificate in credentials, do not check /// the host name, otherwise use `Strict` mode + #[default] Default, } diff --git a/edgedb-tokio/src/lib.rs b/edgedb-tokio/src/lib.rs index 9b796958..6cb60db8 100644 --- a/edgedb-tokio/src/lib.rs +++ b/edgedb-tokio/src/lib.rs @@ -140,7 +140,7 @@ pub mod tutorial; pub use edgedb_derive::{ConfigDelta, GlobalsDelta, Queryable}; -pub use builder::{Builder, ClientSecurity, Config, InstanceName}; +pub use builder::{Builder, ClientSecurity, Config, InstanceName, TcpKeepalive}; pub use client::Client; pub use credentials::TlsSecurity; pub use errors::Error; diff --git a/edgedb-tokio/src/raw/connection.rs b/edgedb-tokio/src/raw/connection.rs index f92278f2..e0d071e8 100644 --- a/edgedb-tokio/src/raw/connection.rs +++ b/edgedb-tokio/src/raw/connection.rs @@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut}; use rand::{thread_rng, Rng}; use rustls::pki_types::DnsName; use scram::ScramClient; +use socket2::TcpKeepalive; use tls_api::TlsConnectorBuilder; use tls_api::{TlsConnector, TlsConnectorBox, TlsStream, TlsStreamDyn}; use tls_api_not_tls::TlsConnector as PlainConnector; @@ -335,6 +336,23 @@ async fn connect3(cfg: &Config, tls: &TlsConnectorBox) -> Result Cow::from(server_name), None => { diff --git a/flake.nix b/flake.nix index a8adc065..d0a2a734 100644 --- a/flake.nix +++ b/flake.nix @@ -65,7 +65,7 @@ buildInputs = [ (fenix_pkgs.toolchainOf { channel = "beta"; - sha256 = "sha256-H1BZtppFoMkxdDQ6ZVbTSg9PoKzkvsEbSSPIoB55t1w="; + sha256 = "sha256-WtTNSmxfoiHJEwCUnuDNfRNBZjNrzdBV02Hikw+YE+s="; }).defaultToolchain ] ++ common; };