diff --git a/Cargo.lock b/Cargo.lock index 91fb7fd58..bfc0882f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,6 +34,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", + "const-random", "once_cell", "version_check", "zerocopy", @@ -482,6 +483,26 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom 0.2.14", + "once_cell", + "tiny-keccak", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -573,6 +594,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.19" @@ -787,6 +817,18 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "flurry" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7874ce5eeafa5e546227f7c62911e586387bf03d6c9a45ac78aa1c3bc2fedb61" +dependencies = [ + "ahash", + "num_cpus", + "parking_lot", + "seize", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1550,6 +1592,15 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "lru" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +dependencies = [ + "hashbrown 0.14.3", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -1978,6 +2029,35 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pingora-pool" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4569e3bef52b0abab239a5cf3287c71307615ca61be7fc7799d71fdaab33d81" +dependencies = [ + "crossbeam-queue", + "log", + "lru", + "parking_lot", + "pingora-timeout", + "thread_local", + "tokio", +] + +[[package]] +name = "pingora-timeout" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be182194d34e1b28608eaa49ee0fb86e5b7ab1d21a1d7a2b4d402446fda47e1" +dependencies = [ + "futures", + "once_cell", + "parking_lot", + "pin-project-lite", + "thread_local", + "tokio", +] + [[package]] name = "plotters" version = "0.3.5" @@ -2574,6 +2654,16 @@ dependencies = [ "libc", ] +[[package]] +name = "seize" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e5739de653b129b0a59da381599cf17caf24bc586f6a797c52d3d6147c5b85a" +dependencies = [ + "num_cpus", + "once_cell", +] + [[package]] name = "semver" version = "1.0.22" @@ -2926,6 +3016,15 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -3630,6 +3729,7 @@ dependencies = [ "diff", "drain", "duration-str", + "flurry", "futures", "futures-core", "futures-util", @@ -3658,6 +3758,8 @@ dependencies = [ "nix 0.28.0", "oid-registry", "once_cell", + "pin-project-lite", + "pingora-pool", "ppp", "pprof", "prometheus-client", diff --git a/Cargo.toml b/Cargo.toml index e46d367b2..d33ef5be4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -97,6 +97,9 @@ url = "2.2" x509-parser = { version = "0.16", default-features = false } tracing-log = "0.2" backoff = "0.4.0" +pin-project-lite = "0.2" +pingora-pool = "0.1.0" +flurry = "0.5.0" [target.'cfg(target_os = "linux")'.dependencies] netns-rs = "0.1" diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock index 7c1963d2a..e02024e91 100644 --- a/fuzz/Cargo.lock +++ b/fuzz/Cargo.lock @@ -24,6 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" dependencies = [ "cfg-if", + "const-random", "once_cell", "version_check", "zerocopy", @@ -364,6 +365,26 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom 0.2.12", + "once_cell", + "tiny-keccak", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -444,6 +465,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.19" @@ -613,6 +643,18 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flurry" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7874ce5eeafa5e546227f7c62911e586387bf03d6c9a45ac78aa1c3bc2fedb61" +dependencies = [ + "ahash", + "num_cpus", + "parking_lot", + "seize", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1302,6 +1344,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "lru" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +dependencies = [ + "hashbrown 0.14.3", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -1641,6 +1692,35 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pingora-pool" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4569e3bef52b0abab239a5cf3287c71307615ca61be7fc7799d71fdaab33d81" +dependencies = [ + "crossbeam-queue", + "log", + "lru", + "parking_lot", + "pingora-timeout", + "thread_local", + "tokio", +] + +[[package]] +name = "pingora-timeout" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be182194d34e1b28608eaa49ee0fb86e5b7ab1d21a1d7a2b4d402446fda47e1" +dependencies = [ + "futures", + "once_cell", + "parking_lot", + "pin-project-lite", + "thread_local", + "tokio", +] + [[package]] name = "plotters" version = "0.3.5" @@ -2207,6 +2287,16 @@ dependencies = [ "libc", ] +[[package]] +name = "seize" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e5739de653b129b0a59da381599cf17caf24bc586f6a797c52d3d6147c5b85a" +dependencies = [ + "num_cpus", + "once_cell", +] + [[package]] name = "semver" version = "1.0.21" @@ -2469,6 +2559,15 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -3165,6 +3264,7 @@ dependencies = [ "chrono", "drain", "duration-str", + "flurry", "futures", "futures-core", "futures-util", @@ -3189,6 +3289,8 @@ dependencies = [ "netns-rs", "nix 0.28.0", "once_cell", + "pin-project-lite", + "pingora-pool", "ppp", "pprof", "prometheus-client", diff --git a/src/admin.rs b/src/admin.rs index 35c43e2f2..686709927 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -395,7 +395,11 @@ fn change_log_level(reset: bool, level: &str) -> Response> { async fn handle_jemalloc_pprof_heapgen( _req: Request, ) -> anyhow::Result>> { - let mut prof_ctl = jemalloc_pprof::PROF_CTL.as_ref()?.lock().await; + let mut prof_ctl = jemalloc_pprof::PROF_CTL + .as_ref() + .expect("should init") + .lock() + .await; if !prof_ctl.activated() { return Ok(Response::builder() .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) @@ -405,7 +409,7 @@ async fn handle_jemalloc_pprof_heapgen( let pprof = prof_ctl.dump_pprof()?; Ok(Response::builder() .status(hyper::StatusCode::OK) - .body(Bytes::from(pprof?).into()) + .body(Bytes::from(pprof).into()) .expect("builder with known status code should not fail")) } diff --git a/src/config.rs b/src/config.rs index 0af5ca85a..216739743 100644 --- a/src/config.rs +++ b/src/config.rs @@ -49,6 +49,8 @@ const CA_ADDRESS: &str = "CA_ADDRESS"; const SECRET_TTL: &str = "SECRET_TTL"; const FAKE_CA: &str = "FAKE_CA"; const ZTUNNEL_WORKER_THREADS: &str = "ZTUNNEL_WORKER_THREADS"; +const POOL_MAX_STREAMS_PER_CONNECTION: &str = "POOL_MAX_STREAMS_PER_CONNECTION"; +const POOL_UNUSED_RELEASE_TIMEOUT: &str = "POOL_UNUSED_RELEASE_TIMEOUT"; const ENABLE_ORIG_SRC: &str = "ENABLE_ORIG_SRC"; const PROXY_CONFIG: &str = "PROXY_CONFIG"; @@ -63,6 +65,8 @@ const DEFAULT_SELFTERM_DEADLINE: Duration = Duration::from_secs(5); const DEFAULT_CLUSTER_ID: &str = "Kubernetes"; const DEFAULT_CLUSTER_DOMAIN: &str = "cluster.local"; const DEFAULT_TTL: Duration = Duration::from_secs(60 * 60 * 24); // 24 hours +const DEFAULT_POOL_UNUSED_RELEASE_TIMEOUT: Duration = Duration::from_secs(60 * 5); // 5 minutes +const DEFAULT_POOL_MAX_STREAMS_PER_CONNECTION: u16 = 100; //Go: 100, Hyper: 200, Envoy: 2147483647 (lol), Spec recommended minimum 100 const DEFAULT_INPOD_MARK: u32 = 1337; @@ -125,6 +129,21 @@ pub struct Config { pub connection_window_size: u32, pub frame_size: u32, + // The limit of how many streams a single HBONE pool connection will be limited to, before + // spawning a new conn rather than reusing an existing one, even to a dest that already has an open connection. + // + // This can be used to effect flow control for "connection storms" when workload clients + // (such as loadgen clients) open many connections all at once. + // + // Note that this will only be checked when a *new* connection + // is requested from the pool, and not on every *stream* queued on that connection. + // So if you request a single connection from a pool configured wiht a max streamcount of 200, + // and queue 500 streams on it, you will still exceed this limit and are at the mercy of hyper's + // default stream queuing. + pub pool_max_streams_per_conn: u16, + + pub pool_unused_release_timeout: Duration, + pub socks5_addr: Option, pub admin_addr: SocketAddr, pub stats_addr: SocketAddr, @@ -321,6 +340,16 @@ pub fn construct_config(pc: ProxyConfig) -> Result { .get(DNS_CAPTURE_METADATA) .map_or(false, |value| value.to_lowercase() == "true"), + pool_max_streams_per_conn: parse_default( + POOL_MAX_STREAMS_PER_CONNECTION, + DEFAULT_POOL_MAX_STREAMS_PER_CONNECTION, + )?, + + pool_unused_release_timeout: match parse::(POOL_UNUSED_RELEASE_TIMEOUT)? { + Some(ttl) => duration_str::parse(ttl).unwrap_or(DEFAULT_POOL_UNUSED_RELEASE_TIMEOUT), + None => DEFAULT_POOL_UNUSED_RELEASE_TIMEOUT, + }, + window_size: 4 * 1024 * 1024, connection_window_size: 4 * 1024 * 1024, frame_size: 1024 * 1024, diff --git a/src/identity/manager.rs b/src/identity/manager.rs index f054813e1..4b80ce826 100644 --- a/src/identity/manager.rs +++ b/src/identity/manager.rs @@ -117,6 +117,8 @@ impl fmt::Display for Identity { } } +// TODO we shouldn't have a "default identity" outside of tests +// #[cfg(test)] impl Default for Identity { fn default() -> Self { const TRUST_DOMAIN: &str = "cluster.local"; diff --git a/src/proxy.rs b/src/proxy.rs index 610a3b98e..080d34ae0 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -49,7 +49,7 @@ mod inbound_passthrough; #[allow(non_camel_case_types)] pub mod metrics; mod outbound; -mod pool; +pub mod pool; mod socks5; mod util; @@ -105,11 +105,11 @@ pub(super) struct ProxyInputs { hbone_port: u16, pub state: DemandProxyState, metrics: Arc, - pool: pool::Pool, socket_factory: Arc, proxy_workload_info: Option>, } +#[allow(clippy::too_many_arguments)] impl ProxyInputs { pub fn new( cfg: config::Config, @@ -126,7 +126,6 @@ impl ProxyInputs { cert_manager, metrics, connection_manager, - pool: pool::Pool::new(), hbone_port: 0, socket_factory, proxy_workload_info: proxy_workload_info.map(Arc::new), @@ -143,15 +142,16 @@ impl Proxy { drain: Watch, ) -> Result { let metrics = Arc::new(metrics); + let socket_factory = Arc::new(DefaultSocketFactory); + let pi = ProxyInputs { cfg, state, cert_manager, connection_manager: ConnectionManager::default(), metrics, - pool: pool::Pool::new(), hbone_port: 0, - socket_factory: Arc::new(DefaultSocketFactory), + socket_factory, proxy_workload_info: None, }; Self::from_inputs(pi, drain).await @@ -245,10 +245,13 @@ pub enum Error { AuthorizationPolicyRejection, #[error("pool is already connecting")] - PoolAlreadyConnecting, + WorkloadHBONEPoolAlreadyConnecting, + + #[error("connection streams maxed out")] + WorkloadHBONEPoolConnStreamsMaxed, - #[error("pool: {0}")] - Pool(#[from] hyper_util::client::legacy::pool::Error), + #[error("pool draining")] + WorkloadHBONEPoolDraining, #[error("{0}")] Generic(Box), diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs index 379a72326..b3d511723 100644 --- a/src/proxy/inbound.rs +++ b/src/proxy/inbound.rs @@ -17,7 +17,7 @@ use std::fmt; use std::fmt::{Display, Formatter}; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use bytes::Bytes; use drain::Watch; @@ -119,6 +119,11 @@ impl Inbound { let serve = crate::hyper_util::http2_server() .initial_stream_window_size(self.pi.cfg.window_size) .initial_connection_window_size(self.pi.cfg.connection_window_size) + // well behaved clients should close connections. + // not all clients are well-behaved. This will prune + // connections when the client is not responding, to keep + // us from holding many stale conns from deceased clients + .keep_alive_interval(Some(Duration::from_secs(10))) .max_frame_size(self.pi.cfg.frame_size) // 64KB max; default is 16MB driven from Golang's defaults // Since we know we are going to recieve a bounded set of headers, more is overkill. diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs index 096fc5b12..911b318f4 100644 --- a/src/proxy/outbound.rs +++ b/src/proxy/outbound.rs @@ -20,7 +20,6 @@ use std::time::Instant; use bytes::Bytes; use drain::Watch; use http_body_util::Empty; -use hyper::client::conn::http2; use hyper::header::FORWARDED; use tokio::net::{TcpListener, TcpStream}; @@ -37,7 +36,7 @@ use crate::proxy::{util, Error, ProxyInputs, TraceParent, BAGGAGE_HEADER, TRACEP use crate::state::service::ServiceDescription; use crate::state::workload::gatewayaddress::Destination; use crate::state::workload::{address::Address, NetworkAddress, Protocol, Workload}; -use crate::{hyper_util, proxy, socket}; +use crate::{proxy, socket}; pub struct Outbound { pi: ProxyInputs, @@ -79,18 +78,24 @@ impl Outbound { // // So use a drain to nuke tasks that might be stuck sending. let (sub_drain_signal, sub_drain) = drain::channel(); + + let pool = proxy::pool::WorkloadHBONEPool::new( + self.pi.cfg.clone(), + self.pi.socket_factory.clone(), + self.pi.cert_manager.clone(), + ); let accept = async move { loop { // Asynchronously wait for an inbound socket. let socket = self.listener.accept().await; let start_outbound_instant = Instant::now(); let outbound_drain = sub_drain.clone(); - let outer_conn_drain = sub_drain.clone(); match socket { Ok((stream, _remote)) => { let mut oc = OutboundConnection { pi: self.pi.clone(), id: TraceParent::new(), + pool: pool.clone(), }; let span = info_span!("outbound", id=%oc.id); tokio::spawn( @@ -101,7 +106,7 @@ impl Outbound { _ = outbound_drain.signaled() => { debug!("outbound drain signaled"); } - _ = oc.proxy(stream, outer_conn_drain.clone()) => {} + _ = oc.proxy(stream) => {} } debug!(dur=?start_outbound_instant.elapsed(), id=%oc.id, "outbound spawn DONE"); }) @@ -135,21 +140,16 @@ impl Outbound { pub(super) struct OutboundConnection { pub(super) pi: ProxyInputs, pub(super) id: TraceParent, + pub(super) pool: proxy::pool::WorkloadHBONEPool, } impl OutboundConnection { - async fn proxy(&mut self, source_stream: TcpStream, outer_conn_drain: Watch) { + async fn proxy(&mut self, source_stream: TcpStream) { let source_addr = socket::to_canonical(source_stream.peer_addr().expect("must receive peer addr")); let dst_addr = socket::orig_dst_addr_or_default(&source_stream); - self.proxy_to( - source_stream, - source_addr, - dst_addr, - false, - Some(outer_conn_drain), - ) - .await; + self.proxy_to(source_stream, source_addr, dst_addr, false) + .await; } // this is a cancellable outbound proxy. If `out_drain` is a Watch drain, will resolve @@ -169,16 +169,15 @@ impl OutboundConnection { ) { match out_drain { Some(drain) => { - let outer_conn_drain = drain.clone(); tokio::select! { _ = drain.signaled() => { - info!("socks drain signaled"); + info!("drain signaled"); } - res = self.proxy_to(stream, remote_addr, orig_dst_addr, block_passthrough, Some(outer_conn_drain)) => res + res = self.proxy_to(stream, remote_addr, orig_dst_addr, block_passthrough) => res } } None => { - self.proxy_to(stream, remote_addr, orig_dst_addr, block_passthrough, None) + self.proxy_to(stream, remote_addr, orig_dst_addr, block_passthrough) .await; } } @@ -190,7 +189,6 @@ impl OutboundConnection { source_addr: SocketAddr, dest_addr: SocketAddr, block_passthrough: bool, - outer_conn_drain: Option, ) { let start = Instant::now(); @@ -249,14 +247,8 @@ impl OutboundConnection { let res = match req.protocol { Protocol::HBONE => { - self.proxy_to_hbone( - &mut source_stream, - source_addr, - outer_conn_drain, - &req, - &result_tracker, - ) - .await + self.proxy_to_hbone(&mut source_stream, source_addr, &req, &result_tracker) + .await } Protocol::TCP => { self.proxy_to_tcp(&mut source_stream, &req, &result_tracker) @@ -270,7 +262,6 @@ impl OutboundConnection { &mut self, stream: &mut TcpStream, remote_addr: SocketAddr, - outer_conn_drain: Option, req: &Request, connection_stats: &ConnectionResult, ) -> Result<(), Error> { @@ -296,75 +287,20 @@ impl OutboundConnection { ); let dst_identity = allowed_sans; - let pool_key = pool::Key { + let pool_key = pool::WorkloadKey { src_id: req.source.identity(), dst_id: dst_identity.clone(), src: remote_addr.ip(), dst: req.gateway, }; - // Setup our connection future. This won't always run if we have an existing connection - // in the pool. - let connect = async { - let mut builder = http2::Builder::new(hyper_util::TokioExecutor); - let builder = builder - .initial_stream_window_size(self.pi.cfg.window_size) - .max_frame_size(self.pi.cfg.frame_size) - .initial_connection_window_size(self.pi.cfg.connection_window_size); - - let local = self - .pi - .cfg - .enable_original_source - .unwrap_or_default() - .then_some(remote_addr.ip()); - let id = &req.source.identity(); - let cert = self.pi.cert_manager.fetch_certificate(id).await?; - let connector = cert.outbound_connector(dst_identity)?; - let tcp_stream = - super::freebind_connect(local, req.gateway, self.pi.socket_factory.as_ref()) - .await?; - tcp_stream.set_nodelay(true)?; // TODO: this is backwards of expectations - let tls_stream = connector.connect(tcp_stream).await?; - let (request_sender, connection) = builder - .handshake(::hyper_util::rt::TokioIo::new(tls_stream)) - .await - .map_err(Error::HttpHandshake)?; - - // spawn a task to poll the connection and drive the HTTP state - // if we got a drain for that connection, respect it in a race - match outer_conn_drain { - Some(conn_drain) => { - tokio::spawn(async move { - tokio::select! { - _ = conn_drain.signaled() => { - debug!("draining outer HBONE connection"); - } - res = connection=> { - match res { - Err(e) => { - error!("Error in HBONE connection handshake: {:?}", e); - } - Ok(_) => { - debug!("done with HBONE connection handshake: {:?}", res); - } - } - } - } - }); - } - None => { - tokio::spawn(async move { - if let Err(e) = connection.await { - error!("Error in HBONE connection handshake: {:?}", e); - } - }); - } - } - - Ok(request_sender) - }; - let mut connection = self.pi.pool.connect(pool_key.clone(), connect).await?; + debug!("outbound - connection get START"); + let mut connection = self + .pool + .connect(pool_key.clone()) + .instrument(trace_span!("get pool conn")) + .await?; + debug!("outbound - connection get END"); let mut f = http_types::proxies::Forwarded::new(); f.add_for(remote_addr.to_string()); @@ -386,11 +322,15 @@ impl OutboundConnection { // There are scenarios (upstream hangup, etc) where this "send" will simply get stuck. // As in, stream processing deadlocks, and `send_request` never resolves to anything. // Probably related to https://github.com/hyperium/hyper/issues/3623 - let response = connection.send_request(request).await?; + let response = connection + .send_request(request) + .instrument(trace_span!("send pool conn")) + .await?; debug!("outbound - connection send END"); let code = response.status(); if code != 200 { + debug!("outbound - connection send FAIL: {code}"); return Err(Error::HttpStatus(code)); } let upgraded = hyper::upgrade::on(response).await?; @@ -708,19 +648,22 @@ mod tests { XdsAddressType::Workload(wl) => new_proxy_state(&[source, waypoint, wl], &[], &[]), XdsAddressType::Service(svc) => new_proxy_state(&[source, waypoint], &[svc], &[]), }; + + let sock_fact = std::sync::Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); let outbound = OutboundConnection { pi: ProxyInputs { cert_manager: identity::mock::new_secret_manager(Duration::from_secs(10)), state, hbone_port: 15008, - cfg, + cfg: cfg.clone(), metrics: test_proxy_metrics(), - pool: pool::Pool::new(), - socket_factory: std::sync::Arc::new(crate::proxy::DefaultSocketFactory), + socket_factory: sock_fact.clone(), proxy_workload_info: None, connection_manager: ConnectionManager::default(), }, id: TraceParent::new(), + pool: pool::WorkloadHBONEPool::new(cfg, sock_fact, cert_mgr.clone()), }; let req = outbound diff --git a/src/proxy/pool.rs b/src/proxy/pool.rs index f274617d2..502a479ae 100644 --- a/src/proxy/pool.rs +++ b/src/proxy/pool.rs @@ -12,241 +12,1458 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::{Error, SocketFactory}; use bytes::Bytes; -use futures::pin_mut; -use futures_util::future; -use futures_util::future::Either; use http_body_util::Empty; use hyper::body::Incoming; use hyper::client::conn::http2; use hyper::http::{Request, Response}; -use hyper_util::client::legacy::pool; -use hyper_util::client::legacy::pool::{Pool as HyperPool, Poolable, Pooled, Reservation}; -use hyper_util::rt::TokioTimer; +use std::time::Duration; + +use std::collections::hash_map::DefaultHasher; use std::future::Future; +use std::hash::{Hash, Hasher}; use std::net::IpAddr; use std::net::SocketAddr; -use std::time::Duration; -use tracing::debug; +use std::sync::atomic::{AtomicI32, AtomicU16, Ordering}; +use std::sync::Arc; + +use tokio::sync::watch; +use tokio::task; + +use tokio::sync::Mutex; +use tracing::{debug, error, trace}; + +use crate::config; +use crate::identity::{Identity, SecretManager}; -use crate::identity::Identity; -use crate::proxy::Error; +use flurry; +use pingora_pool; + +// A relatively nonstandard HTTP/2 connection pool designed to allow multiplexing proxied workload connections +// over a (smaller) number of HTTP/2 mTLS tunnels. +// +// The following invariants apply to this pool: +// - Every workload (inpod mode) gets its own connpool. +// - Every unique src/dest key gets their own dedicated connections inside the pool. +// - Every unique src/dest key gets 1-n dedicated connections, where N is (currently) unbounded but practically limited +// by flow control throttling. #[derive(Clone)] -pub struct Pool { - pool: HyperPool, +pub struct WorkloadHBONEPool { + state: Arc, + pool_watcher: watch::Receiver, } -impl Pool { - pub fn new() -> Pool { - Self { - pool: HyperPool::new( - hyper_util::client::legacy::pool::Config { - idle_timeout: Some(Duration::from_secs(90)), - max_idle_per_host: std::usize::MAX, - }, - TokioExec, - Some(TokioTimer::new()), - ), - } - } +// PoolState is effectively the gnarly inner state stuff that needs thread/task sync, and should be wrapped in a Mutex. +struct PoolState { + pool_notifier: watch::Sender, // This is already impl clone? rustc complains that it isn't, tho + timeout_tx: watch::Sender, // This is already impl clone? rustc complains that it isn't, tho + // this is effectively just a convenience data type - a rwlocked hashmap with keying and LRU drops + // and has no actual hyper/http/connection logic. + connected_pool: Arc>, + // this must be an atomic/concurrent-safe list-of-locks, so we can lock per-key, not globally, and avoid holding up all conn attempts + established_conn_writelock: flurry::HashMap>>>, + close_pollers: futures::stream::FuturesUnordered>, + pool_unused_release_timeout: Duration, + // This is merely a counter to track the overall number of conns this pool spawns + // to ensure we get unique poolkeys-per-new-conn, it is not a limit + pool_global_conn_count: AtomicI32, + max_streamcount: u16, + spawner: ConnSpawner, } -#[derive(Clone)] -pub struct TokioExec; +struct ConnSpawner { + cfg: config::Config, + socket_factory: Arc, + cert_manager: Arc, + timeout_rx: watch::Receiver, +} + +// Does nothing but spawn new conns when asked +impl ConnSpawner { + async fn new_pool_conn( + &self, + key: WorkloadKey, + ) -> Result>, Error> { + debug!("spawning new pool conn for key {:#?}", key); + let clone_key = key.clone(); + let mut c_builder = http2::Builder::new(crate::hyper_util::TokioExecutor); + let builder = c_builder + .initial_stream_window_size(self.cfg.window_size) + .max_frame_size(self.cfg.frame_size) + .initial_connection_window_size(self.cfg.connection_window_size); + + let local = self + .cfg + .enable_original_source + .unwrap_or_default() + .then_some(key.src); + let cert = self.cert_manager.fetch_certificate(&key.src_id).await?; + let connector = cert.outbound_connector(key.dst_id)?; + let tcp_stream = + super::freebind_connect(local, key.dst, self.socket_factory.as_ref()).await?; + tcp_stream.set_nodelay(true)?; // TODO: this is backwards of expectations + let tls_stream = connector.connect(tcp_stream).await?; + trace!("connector connected, handshaking"); + let (request_sender, connection) = builder + .handshake(::hyper_util::rt::TokioIo::new(tls_stream)) + .await + .map_err(Error::HttpHandshake)?; + + // spawn a task to poll the connection and drive the HTTP state + // if we got a drain for that connection, respect it in a race + // it is important to have a drain here, or this connection will never terminate + let mut driver_drain = self.timeout_rx.clone(); + tokio::spawn(async move { + debug!("starting a connection driver for {:?}", clone_key); + tokio::select! { + _ = driver_drain.changed() => { + debug!("draining outer HBONE connection {:?}", clone_key); + } + res = connection=> { + match res { + Err(e) => { + error!("Error in HBONE connection handshake: {:?}", e); + } + Ok(_) => { + debug!("done with HBONE connection handshake: {:?}", res); + } + } + } + } + }); -impl hyper::rt::Executor for TokioExec -where - F: std::future::Future + Send + 'static, - F::Output: Send + 'static, -{ - fn execute(&self, fut: F) { - tokio::spawn(fut); + Ok(request_sender) } } -#[derive(Debug, Clone)] -struct Client(http2::SendRequest>); +impl PoolState { + // This simply puts the connection back into the inner pool, + // and sets up a timed popper, which will resolve + // - when this reference is popped back out of the inner pool (doing nothing) + // - when this reference is evicted from the inner pool (doing nothing) + // - when the timeout_idler is drained (will pop) + // - when the timeout is hit (will pop) + // + // Idle poppers are safe to invoke if the conn they are popping is already gone + // from the inner queue, so we will start one for every insert, let them run or terminate on their own, + // and poll them to completion on shutdown - any duplicates from repeated checkouts/checkins of the same conn + // will simply resolve as a no-op in order. + // + // Note that "idle" in the context of this pool means "no one has asked for it or dropped it in X time, so prune it". + // + // Pruning the idle connection from the pool does not close it - it simply ensures the pool stops holding a ref. + // hyper self-closes client conns when all refs are dropped and streamcount is 0, so pool consumers must + // drop their checked out conns and/or terminate their streams as well. + // + // Note that this simply removes the client ref from this pool - if other things hold client/streamrefs refs, + // they must also drop those before the underlying connection is fully closed. + fn checkin_conn(&self, conn: ConnClient, pool_key: pingora_pool::ConnectionMeta) { + let (evict, pickup) = self.connected_pool.put(&pool_key, conn); + let rx = self.spawner.timeout_rx.clone(); + let pool_ref = self.connected_pool.clone(); + let pool_key_ref = pool_key.clone(); + let release_timeout = self.pool_unused_release_timeout; + self.close_pollers.push(tokio::spawn(async move { + debug!( + "starting an idle timeout for connection {:#?}", + pool_key_ref + ); + pool_ref + .idle_timeout(&pool_key_ref, release_timeout, evict, rx, pickup) + .await; + debug!( + "connection {:#?} was removed/checked out/timed out of the pool", + pool_key_ref + ) + })); + let _ = self.pool_notifier.send(true); + } + + // Since we are using a hash key to do lookup on the inner pingora pool, do a get guard + // to make sure what we pull out actually deep-equals the workload_key, to avoid *sigh* crossing the streams. + fn guarded_get( + &self, + hash_key: &u64, + workload_key: &WorkloadKey, + ) -> Result, Error> { + match self.connected_pool.get(hash_key) { + None => Ok(None), + Some(conn) => match Self::enforce_key_integrity(conn, workload_key) { + Err(e) => Err(e), + Ok(conn) => Ok(Some(conn)), + }, + } + } -impl Poolable for Client { - fn is_open(&self) -> bool { - self.0.is_ready() + // Just for safety's sake, since we are using a hash thanks to pingora NOT supporting arbitrary Eq, Hash + // types, do a deep equality test before returning the conn, returning an error if the conn's key does + // not equal the provided key + // + // this is a final safety check for collisions, we will throw up our hands and refuse to return the conn + fn enforce_key_integrity( + conn: ConnClient, + expected_key: &WorkloadKey, + ) -> Result { + match conn.is_for_workload(expected_key) { + Ok(()) => Ok(conn), + Err(e) => Err(e), + } } - fn reserve(self) -> Reservation { - let b = self.clone(); - let a = self; - Reservation::Shared(a, b) + // 1. Tries to get a writelock. + // 2. If successful, hold it, spawn a new connection, check it in, return a clone of it. + // 3. If not successful, return nothing. + // + // This is useful if we want to race someone else to the writelock to spawn a connection, + // and expect the losers to queue up and wait for the (singular) winner of the writelock + // + // This function should ALWAYS return a connection if it wins the writelock for the provided key. + // This function should NEVER return a connection if it does not win the writelock for the provided key. + // This function should ALWAYS propagate Error results to the caller + // + // It is important that the *initial* check here is authoritative, hence the locks, as + // we must know if this is a connection for a key *nobody* has tried to start yet + // (i.e. no writelock for our key in the outer map) + // or if other things have already established conns for this key (writelock for our key in the outer map). + // + // This is so we can backpressure correctly if 1000 tasks all demand a new connection + // to the same key at once, and not eagerly open 1000 tunnel connections. + async fn start_conn_if_win_writelock( + &self, + workload_key: &WorkloadKey, + pool_key: &pingora_pool::ConnectionMeta, + ) -> Result, Error> { + let inner_conn_lock = { + trace!("getting keyed lock out of lockmap"); + let guard = self.established_conn_writelock.guard(); + + let exist_conn_lock = self + .established_conn_writelock + .get(&pool_key.key, &guard) + .unwrap(); + trace!("got keyed lock out of lockmap"); + exist_conn_lock.as_ref().unwrap().clone() + }; + + trace!("attempting to win connlock for wl key {:#?}", workload_key); + + let inner_lock = inner_conn_lock.try_lock(); + match inner_lock { + Ok(_guard) => { + // BEGIN take inner writelock + debug!("nothing else is creating a conn and we won the lock, make one"); + let pool_conn = self.spawner.new_pool_conn(workload_key.clone()).await?; + let client = ConnClient { + sender: pool_conn, + stream_count: Arc::new(AtomicU16::new(0)), + stream_count_max: self.max_streamcount, + wl_key: workload_key.clone(), + }; + + debug!( + "checking in new conn for key {:#?} with pk {:#?}", + workload_key, pool_key + ); + self.checkin_conn(client.clone(), pool_key.clone()); + Ok(Some(client)) + // END take inner writelock + } + Err(_) => { + debug!( + "did not win connlock for wl key {:#?}, something else has it", + workload_key + ); + Ok(None) + } + } + } + + // Does an initial, naive check to see if we have a writelock inserted into the map for this key + // + // If we do, take the writelock for that key, clone (or create) a connection, check it back in, + // and return a cloned ref, then drop the writelock. + // + // Otherwise, return None. + // + // This function should ALWAYS return a connection if a writelock exists for the provided key. + // This function should NEVER return a connection if no writelock exists for the provided key. + // This function should ALWAYS propagate Error results to the caller + // + // It is important that the *initial* check here is authoritative, hence the locks, as + // we must know if this is a connection for a key *nobody* has tried to start yet + // (i.e. no writelock for our key in the outer map) + // or if other things have already established conns for this key (writelock for our key in the outer map). + // + // This is so we can backpressure correctly if 1000 tasks all demand a new connection + // to the same key at once, and not eagerly open 1000 tunnel connections. + async fn checkout_conn_under_writelock( + &self, + workload_key: &WorkloadKey, + pool_key: &pingora_pool::ConnectionMeta, + ) -> Result, Error> { + let found_conn = { + trace!("pool connect outer map - take guard"); + let guard = self.established_conn_writelock.guard(); + + trace!("pool connect outer map - check for keyed mutex"); + let exist_conn_lock = self.established_conn_writelock.get(&pool_key.key, &guard); + exist_conn_lock.and_then(|e_conn_lock| e_conn_lock.clone()) + }; + match found_conn { + Some(exist_conn_lock) => { + debug!( + "checkout - found mutex for pool key {:#?}, waiting for writelock", + pool_key + ); + let _conn_lock = exist_conn_lock.as_ref().lock().await; + + trace!( + "checkout - got writelock for conn with key {:#?} and hash {:#?}", + workload_key, + pool_key.key + ); + let result = match self.guarded_get(&pool_key.key, workload_key)? { + Some(e_conn) => { + trace!("checkout - got existing conn for key {:#?}", workload_key); + if e_conn.at_max_streamcount() { + debug!("got conn for wl key {:#?}, but streamcount is maxed, spawning new conn to replace using pool key {:#?}", workload_key, pool_key); + let pool_conn = + self.spawner.new_pool_conn(workload_key.clone()).await?; + let r_conn = ConnClient { + sender: pool_conn, + stream_count: Arc::new(AtomicU16::new(0)), + stream_count_max: self.max_streamcount, + wl_key: workload_key.clone(), + }; + self.checkin_conn(r_conn.clone(), pool_key.clone()); + Some(r_conn) + } else { + debug!("checking existing conn for key {:#?} back in", pool_key); + self.checkin_conn(e_conn.clone(), pool_key.clone()); + Some(e_conn) + } + } + None => { + trace!( + "checkout - no existing conn for key {:#?}, adding one", + workload_key + ); + let pool_conn = self.spawner.new_pool_conn(workload_key.clone()).await?; + let r_conn = ConnClient { + sender: pool_conn, + stream_count: Arc::new(AtomicU16::new(0)), + stream_count_max: self.max_streamcount, + wl_key: workload_key.clone(), + }; + self.checkin_conn(r_conn.clone(), pool_key.clone()); + Some(r_conn) + } + }; + + Ok(result) + } + None => Ok(None), + } } +} - fn can_share(&self) -> bool { - true // http2 always shares +// When the Arc-wrapped PoolState is finally dropped, trigger the drain, +// which will terminate all connection driver spawns, as well as cancel all outstanding eviction timeout spawns +impl Drop for PoolState { + fn drop(&mut self) { + debug!("poolstate dropping, stopping all connection drivers and cancelling all outstanding eviction timeout spawns"); + let _ = self.timeout_tx.send(true); } } -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -pub struct Key { - pub src_id: Identity, - pub dst_id: Vec, - // In theory we can just use src,dst,node. However, the dst has a check that - // the L3 destination IP matches the HBONE IP. This could be loosened to just assert they are the same identity maybe. - pub dst: SocketAddr, - // Because we spoof the source IP, we need to key on this as well. Note: for in-pod its already per-pod - // pools anyways. - pub src: IpAddr, +impl WorkloadHBONEPool { + // Creates a new pool instance, which should be owned by a single proxied workload. + // The pool will watch the provided drain signal and drain itself when notified. + // Callers should then be safe to drop() the pool instance. + pub fn new( + cfg: crate::config::Config, + socket_factory: Arc, + cert_manager: Arc, + ) -> WorkloadHBONEPool { + let (timeout_tx, timeout_rx) = watch::channel(false); + let (timeout_send, timeout_recv) = watch::channel(false); + let max_count = cfg.pool_max_streams_per_conn; + let pool_duration = cfg.pool_unused_release_timeout; + + let spawner = ConnSpawner { + cfg, + socket_factory, + cert_manager, + timeout_rx: timeout_recv.clone(), + }; + + // This is merely a counter to track the overall number of conns this pool spawns + // to ensure we get unique poolkeys-per-new-conn, it is not a limit + debug!("constructing pool with {:#?} streams per conn", max_count); + + Self { + state: Arc::new(PoolState { + pool_notifier: timeout_tx, + timeout_tx: timeout_send, + // timeout_rx: timeout_recv, + // the number here is simply the number of unique src/dest keys + // the pool is expected to track before the inner hashmap resizes. + connected_pool: Arc::new(pingora_pool::ConnectionPool::new(500)), + established_conn_writelock: flurry::HashMap::new(), + close_pollers: futures::stream::FuturesUnordered::new(), + pool_unused_release_timeout: pool_duration, + pool_global_conn_count: AtomicI32::new(0), + max_streamcount: max_count, + spawner, + }), + pool_watcher: timeout_rx, + } + } + + // Obtain a pooled connection. Will prefer to retrieve an existing conn from the pool, but + // if none exist, or the existing conn is maxed out on streamcount, will spawn a new one, + // even if it is to the same dest+port. + // + // If many `connects` request a connection to the same dest at once, all will wait until exactly + // one connection is created, before deciding if they should create more or just use that one. + pub async fn connect(&mut self, workload_key: WorkloadKey) -> Result { + trace!("pool connect START"); + // TODO BML this may not be collision resistant, or a fast hash. It should be resistant enough for workloads tho. + // We are doing a deep-equals check at the end to mitigate any collisions, will see about bumping Pingora + let mut s = DefaultHasher::new(); + workload_key.hash(&mut s); + let hash_key = s.finish(); + let pool_key = pingora_pool::ConnectionMeta::new( + hash_key, + self.state + .pool_global_conn_count + .fetch_add(1, Ordering::SeqCst), + ); + // First, see if we can naively take an inner lock for our specific key, and get a connection. + // This should be the common case, except for the first establishment of a new connection/key. + // This will be done under outer readlock (nonexclusive)/inner keyed writelock (exclusive). + let existing_conn = self + .state + .checkout_conn_under_writelock(&workload_key, &pool_key) + .await?; + + // Early return, no need to do anything else + if existing_conn.is_some() { + debug!("initial attempt - found existing conn, done"); + return Ok(existing_conn.unwrap()); + } + + // We couldn't get a writelock for this key. This means nobody has tried to establish any conns for this key yet, + // So, we will take a nonexclusive readlock on the outer lockmap, and attempt to insert one. + // + // (if multiple threads try to insert one, only one will succeed.) + { + debug!( + "didn't find a connection for key {:#?}, making sure lockmap has entry", + hash_key + ); + let guard = self.state.established_conn_writelock.guard(); + match self.state.established_conn_writelock.try_insert( + hash_key, + Some(Arc::new(Mutex::new(()))), + &guard, + ) { + Ok(_) => { + debug!("inserting conn mutex for key {:#?} into lockmap", hash_key); + } + Err(_) => { + debug!("already have conn for key {:#?} in lockmap", hash_key); + } + } + } + + // If we get here, it means the following are true: + // 1. We have a guaranteed sharded mutex in the outer map for our current key + // 2. We can now, under readlock(nonexclusive) in the outer map, attempt to + // take the inner writelock for our specific key (exclusive). + // + // This doesn't block other tasks spawning connections against other keys, but DOES block other + // tasks spawning connections against THIS key - which is what we want. + + // NOTE: The inner, key-specific mutex is a tokio::async::Mutex, and not a stdlib sync mutex. + // these differ from the stdlib sync mutex in that they are (slightly) slower + // (they effectively sleep the current task) and they can be held over an await. + // The tokio docs (rightly) advise you to not use these, + // because holding a lock over an await is a great way to create deadlocks if the await you + // hold it over does not resolve. + // + // HOWEVER. Here we know this connection will either establish or timeout (or fail with error) + // and we WANT other tasks to go back to sleep if a task is already trying to create a new connection for this key. + // + // So the downsides are actually useful (we WANT task contention - + // to block other parallel tasks from trying to spawn a connection for this key if we are already doing so) + trace!("fallback attempt - trying win win connlock"); + let res = match self + .state + .start_conn_if_win_writelock(&workload_key, &pool_key) + .await? + { + Some(client) => client, + None => { + debug!("we didn't win the lock, something else is creating a conn, wait for it"); + // If we get here, it means the following are true: + // 1. We have a writelock in the outer map for this key (either we inserted, or someone beat us to it - but it's there) + // 2. We could not get the exclusive inner writelock to add a new conn for this key. + // 3. Someone else got the exclusive inner writelock, and is adding a new conn for this key. + // + // So, loop and wait for the pool_watcher to tell us a new conn was enpooled, + // so we can pull it out and check it. + loop { + match self.pool_watcher.changed().await { + Ok(_) => { + trace!( + "notified a new conn was enpooled, checking for hash {:#?}", + hash_key + ); + // Notifier fired, try and get a conn out for our key. + let existing_conn = self + .state + .checkout_conn_under_writelock(&workload_key, &pool_key) + .await?; + match existing_conn { + None => { + trace!("woke up on pool notification, but didn't find a conn for {:#?} yet", hash_key); + continue; + } + Some(e_conn) => { + debug!("found existing conn after waiting"); + break e_conn; + } + } + } + Err(_) => { + return Err(Error::WorkloadHBONEPoolDraining); + } + } + } + } + }; + Ok(res) + } } #[derive(Debug)] -pub struct Connection(Pooled); +// A sort of faux-client, that represents a single checked-out 'request sender' which might +// send requests over some underlying stream using some underlying http/2 client +pub struct ConnClient { + sender: http2::SendRequest>, + stream_count: Arc, // the current streamcount for this client conn. + stream_count_max: u16, // the max streamcount associated with this client. + // A WL key may have many clients, but every client has no more than one WL key + wl_key: WorkloadKey, // the WL key associated with this client. +} + +impl ConnClient { + pub fn at_max_streamcount(&self) -> bool { + let curr_count = self.stream_count.load(Ordering::Relaxed); + trace!("checking streamcount: {curr_count}"); + if curr_count >= self.stream_count_max { + return true; + } + false + } -impl Connection { pub fn send_request( &mut self, req: Request>, ) -> impl Future>> { - self.0 .0.send_request(req) + // TODO should we enforce streamcount per-sent-request? This would be slow. + self.stream_count.fetch_add(1, Ordering::Relaxed); + self.sender.send_request(req) + } + + pub fn is_for_workload(&self, wl_key: &WorkloadKey) -> Result<(), crate::proxy::Error> { + if !(self.wl_key == *wl_key) { + Err(crate::proxy::Error::Generic( + "fetched connection does not match workload key!".into(), + )) + } else { + Ok(()) + } } } -impl Pool { - pub async fn connect(&self, key: Key, connect: F) -> Result - where - F: Future>, Error>>, - { - let reuse_connection = self.pool.checkout(key.clone()); - - let connect_pool = async { - let ver = pool::Ver::Http2; - let Some(connecting) = self.pool.connecting(&key, ver) else { - // There is already an existing connection establishment in flight. - // Return an error so - return Err(Error::PoolAlreadyConnecting); - }; - let pc = Client(connect.await?); - let pooled = self.pool.pooled(connecting, pc); - Ok::<_, Error>(pooled) - }; - pin_mut!(connect_pool); - let request_sender: Pooled = - match future::select(reuse_connection, connect_pool).await { - // Checkout won. - Either::Left((Ok(conn), _)) => { - debug!(?key, "fetched existing connection"); - conn - } - // Checkout won, but had an error. - Either::Left((Err(err), connecting)) => match err { - // Checked out a closed connection. Just keep connecting then - pool::Error::CheckedOutClosedValue => connecting.await?, - // Some other error, bubble it up - _ => return Err(Error::Pool(err)), - }, - // Connect won, checkout can just be dropped. - Either::Right((Ok(request_sender), _checkout)) => { - debug!(?key, "established new connection"); - request_sender - } - // Connect won, checkout can just be dropped. - Either::Right((Err(err), checkout)) => { - debug!( - ?key, - "connect won, but wait for existing pooled connection to establish" - ); - match err { - // Connect won but we already had an in-flight connection, so use that. - Error::PoolAlreadyConnecting => checkout.await?, - // Some other connection error - err => return Err(err), - } - } - }; +// This is currently only for debugging +impl Drop for ConnClient { + fn drop(&mut self) { + trace!( + "dropping ConnClient for key {:#?} with streamcount: {:?} / {:?}", + self.wl_key, + self.stream_count, + self.stream_count_max + ) + } +} - Ok(Connection(request_sender)) +// This is currently only for debugging +impl Clone for ConnClient { + fn clone(&self) -> Self { + trace!( + "cloning ConnClient for key {:#?} with streamcount: {:?} / {:?}", + self.wl_key, + self.stream_count, + self.stream_count_max + ); + ConnClient { + sender: self.sender.clone(), + stream_count: self.stream_count.clone(), + stream_count_max: self.stream_count_max, + wl_key: self.wl_key.clone(), + } } } + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +pub struct WorkloadKey { + pub src_id: Identity, + pub dst_id: Vec, + // In theory we can just use src,dst,node. However, the dst has a check that + // the L3 destination IP matches the HBONE IP. This could be loosened to just assert they are the same identity maybe. + pub dst: SocketAddr, + // Because we spoof the source IP, we need to key on this as well. Note: for in-pod its already per-pod + // pools anyways. + pub src: IpAddr, +} + #[cfg(test)] mod test { use std::convert::Infallible; use std::net::SocketAddr; + use std::time::Instant; + use crate::identity; + + use drain::Watch; + use futures_util::StreamExt; use hyper::body::Incoming; + use hyper::service::service_fn; use hyper::{Request, Response}; - use tokio::net::{TcpListener, TcpStream}; - use tracing::{error, info}; + use std::sync::atomic::AtomicU32; + use std::time::Duration; + use tokio::io::AsyncWriteExt; + use tokio::net::TcpListener; + use tokio::task::{self}; + use tokio::time::sleep; + + #[cfg(tokio_unstable)] + use tracing::Instrument; + + use ztunnel::test_helpers::*; use super::*; - #[tokio::test] - async fn test_pool() { + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_reuses_conn_for_same_key() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 6, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let client1 = spawn_client(pool.clone(), key1.clone(), server_addr, 2).await; + let client2 = spawn_client(pool.clone(), key1.clone(), server_addr, 2).await; + let client3 = spawn_client(pool.clone(), key1, server_addr, 2).await; + + assert!(client1.is_ok()); + assert!(client2.is_ok()); + assert!(client3.is_ok()); + + server_drain_signal.drain().await; + drop(pool); + server_handle.await.unwrap(); + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!(real_conncount == 1, "actual conncount was {real_conncount}"); + + assert!(client1.is_ok()); + assert!(client2.is_ok()); + assert!(client3.is_ok()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_does_not_reuse_conn_for_diff_key() { + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + // crate::telemetry::setup_logging(); + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 10, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let key2 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 3]), + dst: server_addr, + }; + + let client1 = spawn_client(pool.clone(), key1, server_addr, 2).await; + let client2 = spawn_client(pool.clone(), key2, server_addr, 2).await; + + server_drain_signal.drain().await; + drop(pool); + + server_handle.await.unwrap(); + + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!(real_conncount == 2, "actual conncount was {real_conncount}"); + + assert!(client1.is_ok()); + assert!(client2.is_ok()); // expect this to panic - we used a new key + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_respects_per_conn_stream_limit() { + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 3, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let client1 = spawn_client(pool.clone(), key1.clone(), server_addr, 4).await; + let client2 = spawn_client(pool.clone(), key1, server_addr, 2).await; + + server_drain_signal.drain().await; + drop(pool); + + server_handle.await.unwrap(); + + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!(real_conncount == 2, "actual conncount was {real_conncount}"); + + assert!(client1.is_ok()); + assert!(client2.is_ok()); // expect this to panic - same key, but stream limit of 3 + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_handles_many_conns_per_key() { + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 2, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let client1 = spawn_client(pool.clone(), key1.clone(), server_addr, 4).await; + let client2 = spawn_client(pool.clone(), key1.clone(), server_addr, 4).await; + + drop(pool); + server_drain_signal.drain().await; + + server_handle.await.unwrap(); + + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!(real_conncount == 2, "actual conncount was {real_conncount}"); + + assert!(client1.is_ok()); + assert!(client2.is_ok()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_100_clients_streamexhaust() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 25, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let client_count = 50; + let mut count = 0u32; + let mut tasks = futures::stream::FuturesUnordered::new(); + loop { + count += 1; + tasks.push(spawn_client(pool.clone(), key1.clone(), server_addr, 1)); + if count == client_count { + break; + } + } + + // TODO we spawn clients too fast (and they have little to do) and they actually break the + // local "fake" test server, causing it to start returning "conn refused/peer refused the connection" + // when the pool tries to create new connections for that caller + // + // (the pool will just pass that conn refused back to the caller) + // + // In the real world this is fine, since we aren't hitting a local server, + // servers can refuse connections - in synthetic tests it leads to flakes. + // + // It is worth considering if the pool should throttle how frequently it allows itself to create + // connections to real upstreams (e.g. "I created a conn for this key 10ms ago and you've already burned through + // your streamcount, chill out, you're gonna overload the dest") + // + // For now, streamcount is an inexact flow control for this. + sleep(Duration::from_millis(500)).await; + + while let Some(Err(res)) = tasks.next().await { + assert!(!res.is_panic(), "CLIENT PANICKED!"); + continue; + } + + server_drain_signal.drain().await; + server_handle.await.unwrap(); + drop(pool); + + let real_conncount = conn_counter.load(Ordering::SeqCst); + assert!(real_conncount == 2, "actual conncount was {real_conncount}"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_100_clients_singleconn() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 1000, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let client_count = 100; + let mut count = 0u32; + let mut tasks = futures::stream::FuturesUnordered::new(); + loop { + count += 1; + tasks.push(spawn_client(pool.clone(), key1.clone(), server_addr, 1)); + + if count == client_count { + break; + } + } + while let Some(Err(res)) = tasks.next().await { + assert!(!res.is_panic(), "CLIENT PANICKED!"); + continue; + } + + drop(pool); + + server_drain_signal.drain().await; + server_handle.await.unwrap(); + + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!(real_conncount == 1, "actual conncount was {real_conncount}"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_100_clients_100_srcs() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 100, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let client_count = 100; + let mut count = 0u8; + let mut tasks = futures::stream::FuturesUnordered::new(); + loop { + count += 1; + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, count]), + dst: server_addr, + }; + + tasks.push(spawn_client(pool.clone(), key1.clone(), server_addr, 20)); + + if count == client_count { + break; + } + } + + while let Some(Err(res)) = tasks.next().await { + assert!(!res.is_panic(), "CLIENT PANICKED!"); + continue; + } + + drop(pool); + + server_drain_signal.drain().await; + server_handle.await.unwrap(); + + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!( + real_conncount == 100, + "actual conncount was {real_conncount}" + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_1000_clients_3_srcs() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 1000, + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let mut key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 1]), + dst: server_addr, + }; + + let client_count = 100; + let mut count = 0u32; + let mut tasks = futures::stream::FuturesUnordered::new(); + loop { + count += 1; + if count % 2 == 0 { + debug!("using key 2"); + key1.src = IpAddr::from([127, 0, 0, 4]); + } else if count % 3 == 0 { + debug!("using key 3"); + key1.src = IpAddr::from([127, 0, 0, 6]); + } else { + debug!("using key 1"); + key1.src = IpAddr::from([127, 0, 0, 2]); + } + + tasks.push(spawn_client(pool.clone(), key1.clone(), server_addr, 50)); + + if count == client_count { + break; + } + } + while let Some(Err(res)) = tasks.next().await { + assert!(!res.is_panic(), "CLIENT PANICKED!"); + continue; + } + + drop(pool); + + server_drain_signal.drain().await; + server_handle.await.unwrap(); + + let real_conncount = conn_counter.load(Ordering::Relaxed); + assert!(real_conncount == 3, "actual conncount was {real_conncount}"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_1000_clients_3_srcs_drops_after_timeout() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 1000, + pool_unused_release_timeout: Duration::from_secs(1), + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let mut key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 1]), + dst: server_addr, + }; + + let client_count = 100; + let mut count = 0u32; + let mut tasks = futures::stream::FuturesUnordered::new(); + loop { + count += 1; + if count % 2 == 0 { + debug!("using key 2"); + key1.src = IpAddr::from([127, 0, 0, 4]); + } else if count % 3 == 0 { + debug!("using key 3"); + key1.src = IpAddr::from([127, 0, 0, 6]); + } else { + debug!("using key 1"); + key1.src = IpAddr::from([127, 0, 0, 2]); + } + + tasks.push(spawn_client(pool.clone(), key1.clone(), server_addr, 50)); + + if count == client_count { + break; + } + } + while let Some(Err(res)) = tasks.next().await { + assert!(!res.is_panic(), "CLIENT PANICKED!"); + continue; + } + + let before_conncount = conn_counter.load(Ordering::Relaxed); + let before_dropcount = conn_drop_counter.load(Ordering::Relaxed); + assert!( + before_conncount == 3, + "actual before conncount was {before_conncount}" + ); + assert!( + before_dropcount != 3, + "actual before dropcount was {before_dropcount}" + ); + + // Attempt to wait long enough for pool conns to timeout+evict + sleep(Duration::from_secs(1)).await; + + let real_conncount = conn_counter.load(Ordering::Relaxed); + let real_dropcount = conn_drop_counter.load(Ordering::Relaxed); + assert!(real_conncount == 3, "actual conncount was {real_conncount}"); + assert!(real_dropcount == 3, "actual dropcount was {real_dropcount}"); + + server_drain_signal.drain().await; + server_handle.await.unwrap(); + drop(pool); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_pool_100_clients_evicts_but_does_not_close_active_conn() { + // crate::telemetry::setup_logging(); + + let (server_drain_signal, server_drain) = drain::channel(); + + let conn_counter: Arc = Arc::new(AtomicU32::new(0)); + let conn_drop_counter: Arc = Arc::new(AtomicU32::new(0)); + let (server_addr, server_handle) = spawn_server( + server_drain, + conn_counter.clone(), + conn_drop_counter.clone(), + ) + .await; + + let cfg = crate::config::Config { + local_node: Some("local-node".to_string()), + pool_max_streams_per_conn: 50, + pool_unused_release_timeout: Duration::from_secs(1), + ..crate::config::parse_config().unwrap() + }; + let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory); + let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10)); + let pool = WorkloadHBONEPool::new(cfg.clone(), sock_fact, cert_mgr); + + let key1 = WorkloadKey { + src_id: Identity::default(), + dst_id: vec![Identity::default()], + src: IpAddr::from([127, 0, 0, 2]), + dst: server_addr, + }; + let client_count = 100; + let mut count = 0u32; + let mut tasks = futures::stream::FuturesUnordered::new(); + loop { + count += 1; + tasks.push(spawn_client(pool.clone(), key1.clone(), server_addr, 1)); + + if count == client_count { + break; + } + } + + // TODO we spawn clients too fast (and they have little to do) and they actually break the + // local "fake" test server, causing it to start returning "conn refused/peer refused the connection" + // when the pool tries to create new connections for that caller + // + // (the pool will just pass that conn refused back to the caller) + // + // In the real world this is fine, since we aren't hitting a local server, + // servers can refuse connections - in synthetic tests it leads to flakes. + // + // It is worth considering if the pool should throttle how frequently it allows itself to create + // connections to real upstreams (e.g. "I created a conn for this key 10ms ago and you've already burned through + // your streamcount, chill out, you're gonna overload the dest") + // + // For now, streamcount is an inexact flow control for this. + sleep(Duration::from_millis(500)).await; + //loop thru the nonpersistent clients and wait for them to finish + while let Some(Err(res)) = tasks.next().await { + assert!(!res.is_panic(), "CLIENT PANICKED!"); + continue; + } + + let (client_stop_signal, client_stop) = drain::channel(); + let persist_res = + spawn_persistent_client(pool.clone(), key1.clone(), server_addr, client_stop); + + //Attempt to wait a bit more, to ensure the connections NOT held open by our persistent client are dropped. + sleep(Duration::from_secs(1)).await; + let before_conncount = conn_counter.load(Ordering::Relaxed); + let before_dropcount = conn_drop_counter.load(Ordering::Relaxed); + assert!( + before_conncount == 3, + "actual before conncount was {before_conncount}" + ); + // At this point, we should still have one conn that hasn't been dropped + // because we haven't ended the persistent client + assert!( + before_dropcount == 2, + "actual before dropcount was {before_dropcount}" + ); + + client_stop_signal.drain().await; + assert!(persist_res.await.is_ok(), "PERSIST CLIENT ERROR"); + + //Attempt to wait a bit more, to ensure the connections held open by our persistent client is dropped. + sleep(Duration::from_secs(1)).await; + + let after_conncount = conn_counter.load(Ordering::Relaxed); + assert!( + after_conncount == 3, + "after conncount was {after_conncount}" + ); + let after_dropcount = conn_drop_counter.load(Ordering::Relaxed); + assert!( + after_dropcount == 3, + "after dropcount was {after_dropcount}" + ); + server_drain_signal.drain().await; + server_handle.await.unwrap(); + + drop(pool); + } + + fn spawn_client( + mut pool: WorkloadHBONEPool, + key: WorkloadKey, + remote_addr: SocketAddr, + req_count: u32, + ) -> task::JoinHandle<()> { + tokio::spawn(async move { + let req = || { + hyper::Request::builder() + .uri(format!("{remote_addr}")) + .method(hyper::Method::CONNECT) + .version(hyper::Version::HTTP_2) + .body(Empty::::new()) + .unwrap() + }; + + let start = Instant::now(); + + let mut c1 = pool + .connect(key.clone()) + // needs tokio_unstable, but useful + // .instrument(tracing::debug_span!("client_tid", tid=%tokio::task::id())) + .await + .expect("connect should succeed"); + debug!( + "client spent {}ms waiting for conn", + start.elapsed().as_millis() + ); + + let mut count = 0u32; + loop { + count += 1; + let res = c1.send_request(req()).await; + + if res.is_err() { + panic!("SEND ERR: {:#?} sendcount {count}", res); + } + + if count >= req_count { + debug!("CLIENT DONE"); + break; + } + } + }) + } + + fn spawn_persistent_client( + mut pool: WorkloadHBONEPool, + key: WorkloadKey, + remote_addr: SocketAddr, + stop: Watch, + ) -> task::JoinHandle<()> { + tokio::spawn(async move { + let req = || { + hyper::Request::builder() + .uri(format!("{remote_addr}")) + .method(hyper::Method::CONNECT) + .version(hyper::Version::HTTP_2) + .body(Empty::::new()) + .unwrap() + }; + + let start = Instant::now(); + + let mut c1 = pool + .connect(key.clone()) + // needs tokio_unstable, but useful + // .instrument(tracing::debug_span!("client_tid", tid=%tokio::task::id())) + .await + .unwrap(); + debug!( + "client spent {}ms waiting for conn", + start.elapsed().as_millis() + ); + + let send_loop = async move { + //send once, then hold the conn open until signaled + let res = c1.send_request(req()).await; + if res.is_err() { + panic!("SEND ERR: {:#?}", res); + } + loop { + debug!("persistent client yielding"); + sleep(Duration::from_millis(1)).await; //yield may be enough + tokio::task::yield_now().await; + } + }; + + tokio::select! { + _ = send_loop => {} + _ = stop.signaled() => { + debug!("GOT STOP PERSISTENT CLIENT"); + } + }; + }) + } + + async fn spawn_server( + stop: Watch, + conn_count: Arc, + conn_drop_count: Arc, + ) -> (SocketAddr, task::JoinHandle<()>) { // We'll bind to 127.0.0.1:3000 let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let test_cfg = test_config(); async fn hello_world(req: Request) -> Result>, Infallible> { - info!("got req {req:?}"); - Ok(Response::builder().status(200).body(Empty::new()).unwrap()) + debug!("hello world: received request"); + tokio::task::spawn(async move { + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + let mut io = hyper_util::rt::TokioIo::new(upgraded); + io.write_all(b"poolsrv\n").await.unwrap(); + tcp::handle_stream(tcp::Mode::ReadWrite, &mut io).await; + } + Err(e) => panic!("No upgrade {e}"), + } + }); + Ok::<_, Infallible>(Response::new(http_body_util::Empty::::new())) } // We create a TcpListener and bind it to 127.0.0.1:3000 let listener = TcpListener::bind(addr).await.unwrap(); + let bound_addr = listener.local_addr().unwrap(); - let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { + let certs = crate::tls::mock::generate_test_certs( + &Identity::default().into(), + Duration::from_secs(0), + Duration::from_secs(100), + ); + let acceptor = crate::tls::mock::MockServerCertProvider::new(certs); + let mut tls_stream = crate::hyper_util::tls_server(acceptor, listener); + + let srv_handle = tokio::spawn(async move { // We start a loop to continuously accept incoming connections - loop { - let (stream, _) = listener.accept().await.unwrap(); - - // Spawn a tokio task to serve multiple connections concurrently - tokio::task::spawn(async move { - // Finally, we bind the incoming connection to our `hello` service - if let Err(err) = crate::hyper_util::http2_server() - .serve_connection( - hyper_util::rt::TokioIo::new(stream), - service_fn(hello_world), - ) - .await - { - println!("Error serving connection: {:?}", err); - } - }); - } - }); - let pool = Pool::new(); - let key = Key { - src_id: Identity::default(), - dst_id: vec![Identity::default()], - src: IpAddr::from([127, 0, 0, 2]), - dst: addr, - }; - let connect = || async { - let builder = http2::Builder::new(TokioExec); - - let tcp_stream = TcpStream::connect(addr).await?; - let (request_sender, connection) = builder - .handshake(hyper_util::rt::TokioIo::new(tcp_stream)) - .await?; - // spawn a task to poll the connection and drive the HTTP state - tokio::spawn(async move { - if let Err(e) = connection.await { - error!("Error in connection handshake: {:?}", e); + // and also count them + let movable_count = conn_count.clone(); + let movable_drop_count = conn_drop_count.clone(); + let accept = async move { + loop { + let stream = tls_stream.next().await.unwrap(); + movable_count.fetch_add(1, Ordering::Relaxed); + let dcount = movable_drop_count.clone(); + debug!("bump serverconn"); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = crate::hyper_util::http2_server() + .initial_stream_window_size(test_cfg.window_size) + .initial_connection_window_size(test_cfg.connection_window_size) + .max_frame_size(test_cfg.frame_size) + // 64KB max; default is 16MB driven from Golang's defaults + // Since we know we are going to recieve a bounded set of headers, more is overkill. + .max_header_list_size(65536) + .serve_connection( + hyper_util::rt::TokioIo::new(stream), + service_fn(hello_world), + ) + .await + { + println!("Error serving connection: {:?}", err); + } + dcount.fetch_add(1, Ordering::Relaxed); + }); } - }); - Ok(request_sender) - }; - let req = || { - hyper::Request::builder() - .uri(format!("http://{addr}")) - .method(hyper::Method::GET) - .version(hyper::Version::HTTP_2) - .body(Empty::::new()) - .unwrap() - }; - let mut c1 = pool.connect(key.clone(), connect()).await.unwrap(); - let mut c2 = pool - .connect(key, async { unreachable!("should use pooled connection") }) - .await - .unwrap(); - assert_eq!(c1.send_request(req()).await.unwrap().status(), 200); - assert_eq!(c1.send_request(req()).await.unwrap().status(), 200); - assert_eq!(c2.send_request(req()).await.unwrap().status(), 200); + }; + tokio::select! { + _ = accept => {} + _ = stop.signaled() => { + debug!("GOT STOP SERVER"); + } + }; + }); + + (bound_addr, srv_handle) } } diff --git a/src/proxy/socks5.rs b/src/proxy/socks5.rs index d935d5ab4..0346a7e2b 100644 --- a/src/proxy/socks5.rs +++ b/src/proxy/socks5.rs @@ -65,12 +65,20 @@ impl Socks5 { let socket = self.listener.accept().await; let inpod = self.pi.cfg.inpod_enabled; let stream_drain = inner_drain.clone(); + // TODO creating a new HBONE pool for SOCKS5 here may not be ideal, + // but ProxyInfo is overloaded and only `outbound` should ever use the pool. + let pool = crate::proxy::pool::WorkloadHBONEPool::new( + self.pi.cfg.clone(), + self.pi.socket_factory.clone(), + self.pi.cert_manager.clone(), + ); match socket { Ok((stream, remote)) => { info!("accepted outbound connection from {}", remote); let oc = OutboundConnection { pi: self.pi.clone(), id: TraceParent::new(), + pool, }; tokio::spawn(async move { if let Err(err) = handle(oc, stream, stream_drain, inpod).await {