Skip to content

Commit

Permalink
Set keep-alive on TCP sockets (#330)
Browse files Browse the repository at this point in the history
Migrations and other long-running operations may leave a TCP connection active but without any traffic.

To prevent the operating systems on either end from closing the connections (or any other system that monitors for activity on TCP streams to determine when they should be closed), we can set the keep-alive flag on the client.
  • Loading branch information
mmastrac authored Jun 27, 2024
1 parent 0937298 commit 89662fd
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 10 deletions.
1 change: 1 addition & 0 deletions edgedb-tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ once_cell = "1.9.0"
tokio-stream = { version = "0.1.11", optional = true }
base64 = "0.22.1"
crc16 = "0.4.0"
socket2 = "0.5"

[target.'cfg(target_family="unix")'.dev-dependencies]
command-fds = "0.3.0"
Expand Down
57 changes: 50 additions & 7 deletions edgedb-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::tls;

pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
pub const DEFAULT_WAIT: Duration = Duration::from_secs(30);
pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60);
pub const DEFAULT_POOL_SIZE: usize = 10;
pub const DEFAULT_HOST: &str = "localhost";
pub const DEFAULT_PORT: u16 = 5656;
Expand All @@ -44,13 +45,14 @@ static PORT_WARN: std::sync::Once = std::sync::Once::new();
type Verifier = Arc<dyn ServerCertVerifier>;

/// Client security mode.
#[derive(Debug, Clone, Copy)]
#[derive(Default, Debug, Clone, Copy)]
pub enum ClientSecurity {
/// Disable security checks
InsecureDevMode,
/// Always verify domain an certificate
Strict,
/// Verify domain only if no specific certificate is configured
#[default]
Default,
}

Expand All @@ -61,6 +63,28 @@ pub enum CloudCerts {
Local,
}

/// TCP keepalive configuration.
#[derive(Default, Debug, Clone, Copy)]
pub enum TcpKeepalive {
/// Disable TCP keepalive probes.
Disabled,
/// Explicit duration between TCP keepalive probes.
Explicit(Duration),
/// Default: 60 seconds.
#[default]
Default,
}

impl TcpKeepalive {
fn as_keepalive(&self) -> Option<Duration> {
match self {
TcpKeepalive::Disabled => None,
TcpKeepalive::Default => Some(DEFAULT_TCP_KEEPALIVE),
TcpKeepalive::Explicit(duration) => Some(*duration),
}
}
}

/// A builder used to create connections.
#[derive(Debug, Clone, Default)]
pub struct Builder {
Expand All @@ -83,6 +107,7 @@ pub struct Builder {
wait_until_available: Option<Duration>,
admin: bool,
connect_timeout: Option<Duration>,
tcp_keepalive: Option<TcpKeepalive>,
secret_key: Option<String>,
cloud_profile: Option<String>,

Expand Down Expand Up @@ -121,6 +146,9 @@ pub(crate) struct ConfigInner {
#[allow(dead_code)] // used only on unstable feature
pub creds_file_outdated: bool,

// Whether to set TCP keepalive or not
pub tcp_keepalive: Option<Duration>,

// Pool configuration
pub max_concurrency: Option<usize>,

Expand Down Expand Up @@ -785,6 +813,21 @@ impl Builder {
self
}

/// Sets the TCP keepalive interval and time for the database connection to
/// ensure that the remote end of the connection is still alive, and to
/// inform any network intermediaries that this connection is not idle. By
/// default, a keepalive probe will be sent once every 60 seconds once the
/// connection has been idle for 60 seconds.
///
/// Note: If the connection is not made over a TCP socket, this value will
/// be unused. If the current platform does not support explicit TCP
/// keep-alive intervals on the socket, keepalives will be enabled and the
/// operating-system default for the intervals will be used.
pub fn tcp_keepalive(&mut self, tcp_keepalive: TcpKeepalive) -> &mut Self {
self.tcp_keepalive = Some(tcp_keepalive);
self
}

/// Set the maximum number of underlying database connections.
pub fn max_concurrency(&mut self, value: usize) -> &mut Self {
self.max_concurrency = Some(value);
Expand Down Expand Up @@ -869,17 +912,17 @@ impl Builder {
.pem_certificates
.clone()
.or_else(|| creds.and_then(|c| c.tls_ca.clone())),

tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(),
// Pool configuration
max_concurrency: self.max_concurrency,

// Temporary placeholders
verifier: Arc::new(tls::NullVerifier),
client_security: self.client_security.unwrap_or(ClientSecurity::Default),
client_security: self.client_security.unwrap_or_default(),
tls_security: self
.tls_security
.or_else(|| creds.map(|c| c.tls_security))
.unwrap_or(TlsSecurity::Default),
.unwrap_or_default(),
};

cfg.verifier = cfg.make_verifier(cfg.compute_tls_security()?);
Expand Down Expand Up @@ -1504,9 +1547,9 @@ impl Builder {
extra_dsn_query_args: HashMap::new(),
creds_file_outdated: false,
pem_certificates: self.pem_certificates.clone(),
client_security: self.client_security.unwrap_or(ClientSecurity::Default),
tls_security: self.tls_security.unwrap_or(TlsSecurity::Default),

client_security: self.client_security.unwrap_or_default(),
tls_security: self.tls_security.unwrap_or_default(),
tcp_keepalive: self.tcp_keepalive.unwrap_or_default().as_keepalive(),
// Pool configuration
max_concurrency: self.max_concurrency,

Expand Down
3 changes: 2 additions & 1 deletion edgedb-tokio/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use serde::{ser, Deserialize, Serialize};
use crate::errors::{Error, ErrorKind};

/// TLS Client Security Mode
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TlsSecurity {
/// Allow any certificate for TLS connection
Expand All @@ -23,6 +23,7 @@ pub enum TlsSecurity {
Strict,
/// If there is a specific certificate in credentials, do not check
/// the host name, otherwise use `Strict` mode
#[default]
Default,
}

Expand Down
2 changes: 1 addition & 1 deletion edgedb-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub mod tutorial;

pub use edgedb_derive::{ConfigDelta, GlobalsDelta, Queryable};

pub use builder::{Builder, ClientSecurity, Config, InstanceName};
pub use builder::{Builder, ClientSecurity, Config, InstanceName, TcpKeepalive};
pub use client::Client;
pub use credentials::TlsSecurity;
pub use errors::Error;
Expand Down
18 changes: 18 additions & 0 deletions edgedb-tokio/src/raw/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut};
use rand::{thread_rng, Rng};
use rustls::pki_types::DnsName;
use scram::ScramClient;
use socket2::TcpKeepalive;
use tls_api::TlsConnectorBuilder;
use tls_api::{TlsConnector, TlsConnectorBox, TlsStream, TlsStreamDyn};
use tls_api_not_tls::TlsConnector as PlainConnector;
Expand Down Expand Up @@ -335,6 +336,23 @@ async fn connect3(cfg: &Config, tls: &TlsConnectorBox) -> Result<TlsStream, Erro
let conn = TcpStream::connect(addr)
.await
.map_err(ClientConnectionError::with_source)?;

// Set keep-alive on the socket, but don't fail if this isn't successful
if let Some(keepalive) = cfg.0.tcp_keepalive {
let sock = socket2::SockRef::from(&conn);
#[cfg(target_os = "openbsd")]
let res = sock.set_keepalive(true);
#[cfg(not(target_os = "openbsd"))]
let res = sock.set_tcp_keepalive(
&TcpKeepalive::new()
.with_interval(keepalive)
.with_time(keepalive),
);
if let Err(e) = res {
log::warn!("Failed to set TCP keepalive: {e:?}");
}
}

let host = match &cfg.0.tls_server_name {
Some(server_name) => Cow::from(server_name),
None => {
Expand Down
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
buildInputs = [
(fenix_pkgs.toolchainOf {
channel = "beta";
sha256 = "sha256-H1BZtppFoMkxdDQ6ZVbTSg9PoKzkvsEbSSPIoB55t1w=";
sha256 = "sha256-WtTNSmxfoiHJEwCUnuDNfRNBZjNrzdBV02Hikw+YE+s=";
}).defaultToolchain
] ++ common;
};
Expand Down

0 comments on commit 89662fd

Please sign in to comment.