diff --git a/src/proxy.rs b/src/proxy.rs
index 7bde75794..5c7a6d474 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -32,6 +32,7 @@ pub use metrics::*;
 
 use crate::identity::SecretManager;
 use crate::metrics::Recorder;
+use crate::proxy::connection_manager::{ConnectionManager, PolicyWatcher};
 use crate::proxy::inbound_passthrough::InboundPassthrough;
 use crate::proxy::outbound::Outbound;
 use crate::proxy::socks5::Socks5;
@@ -41,7 +42,7 @@ use crate::state::workload::{network_addr, Workload};
 use crate::state::DemandProxyState;
 use crate::{config, identity, socket, tls};
 
-mod connection_manager;
+pub mod connection_manager;
 mod inbound;
 mod inbound_passthrough;
 #[allow(non_camel_case_types)]
@@ -91,12 +92,14 @@ pub struct Proxy {
     inbound_passthrough: InboundPassthrough,
     outbound: Outbound,
     socks5: Socks5,
+    policy_watcher: PolicyWatcher,
 }
 
 #[derive(Clone)]
 pub(super) struct ProxyInputs {
     cfg: config::Config,
     cert_manager: Arc<SecretManager>,
+    connection_manager: ConnectionManager,
     hbone_port: u16,
     pub state: DemandProxyState,
     metrics: Arc<Metrics>,
@@ -108,6 +111,7 @@ impl ProxyInputs {
     pub fn new(
         cfg: config::Config,
         cert_manager: Arc<SecretManager>,
+        connection_manager: ConnectionManager,
         state: DemandProxyState,
         metrics: Arc<Metrics>,
         socket_factory: Arc<dyn SocketFactory + Send + Sync>,
@@ -117,6 +121,7 @@ impl ProxyInputs {
             state,
             cert_manager,
             metrics,
+            connection_manager,
             pool: pool::Pool::new(),
             hbone_port: 0,
             socket_factory,
@@ -137,6 +142,7 @@ impl Proxy {
             cfg,
             state,
             cert_manager,
+            connection_manager: ConnectionManager::default(),
             metrics,
             pool: pool::Pool::new(),
             hbone_port: 0,
@@ -152,12 +158,14 @@ impl Proxy {
         let inbound_passthrough = InboundPassthrough::new(pi.clone(), drain.clone()).await?;
         let outbound = Outbound::new(pi.clone(), drain.clone()).await?;
         let socks5 = Socks5::new(pi.clone(), drain.clone()).await?;
+        let policy_watcher = PolicyWatcher::new(pi.state, drain, pi.connection_manager);
 
         Ok(Proxy {
             inbound,
             inbound_passthrough,
             outbound,
             socks5,
+            policy_watcher,
         })
     }
 
@@ -167,6 +175,7 @@ impl Proxy {
             tokio::spawn(self.inbound.run().in_current_span()),
             tokio::spawn(self.outbound.run().in_current_span()),
             tokio::spawn(self.socks5.run().in_current_span()),
+            tokio::spawn(self.policy_watcher.run().in_current_span()),
         ];
 
         futures::future::join_all(tasks).await;
diff --git a/src/proxy/connection_manager.rs b/src/proxy/connection_manager.rs
index 3b503620e..75d549520 100644
--- a/src/proxy/connection_manager.rs
+++ b/src/proxy/connection_manager.rs
@@ -18,7 +18,6 @@ use crate::state::DemandProxyState;
 use drain;
 use std::collections::HashMap;
 use std::sync::Arc;
-use tokio::sync::watch;
 use tokio::sync::RwLock;
 use tracing::info;
 
@@ -51,13 +50,15 @@ pub struct ConnectionManager {
     drains: Arc<RwLock<HashMap<Connection, drain::Signal>>>,
 }
 
-impl ConnectionManager {
-    pub fn new() -> Self {
+impl std::default::Default for ConnectionManager {
+    fn default() -> Self {
         ConnectionManager {
             drains: Arc::new(RwLock::new(HashMap::new())),
         }
     }
+}
 
+impl ConnectionManager {
     // register a connection with the connection manager
     // this must be done before a connection can be tracked
     // allows policy to be asserted against the connection
@@ -119,24 +120,39 @@ impl ConnectionManager {
     }
 }
 
-pub async fn policy_watcher(
+pub struct PolicyWatcher {
     state: DemandProxyState,
-    mut stop_rx: watch::Receiver<()>,
+    stop: drain::Watch,
     connection_manager: ConnectionManager,
-    parent_proxy: &str,
-) {
-    let mut policies_changed = state.read().policies.subscribe();
-    loop {
-        tokio::select! {
-            _ = stop_rx.changed() => {
-                break;
-            }
-            _ = policies_changed.changed() => {
-                let connections = connection_manager.connections().await;
-                for conn in connections {
-                    if !state.assert_rbac(&conn).await {
-                        connection_manager.close(&conn).await;
-                        info!("{parent_proxy} connection {conn} closed because it's no longer allowed after a policy update");
+}
+
+impl PolicyWatcher {
+    pub fn new(
+        state: DemandProxyState,
+        stop: drain::Watch,
+        connection_manager: ConnectionManager,
+    ) -> Self {
+        PolicyWatcher {
+            state,
+            stop,
+            connection_manager,
+        }
+    }
+
+    pub async fn run(self) {
+        let mut policies_changed = self.state.read().policies.subscribe();
+        loop {
+            tokio::select! {
+                _ = self.stop.clone().signaled() => {
+                    break;
+                }
+                _ = policies_changed.changed() => {
+                    let connections = self.connection_manager.connections().await;
+                    for conn in connections {
+                        if !self.state.assert_rbac(&conn).await {
+                            self.connection_manager.close(&conn).await;
+                            info!("connection {conn} closed because it's no longer allowed after a policy update");
+                        }
                     }
                 }
             }
@@ -151,19 +167,18 @@ mod tests {
     use std::net::{Ipv4Addr, SocketAddrV4};
     use std::sync::{Arc, RwLock};
     use std::time::Duration;
-    use tokio::sync::watch;
 
     use crate::rbac::Connection;
     use crate::state::{DemandProxyState, ProxyState};
     use crate::xds::istio::security::{Action, Authorization, Scope};
     use crate::xds::ProxyStateUpdateMutator;
 
-    use super::ConnectionManager;
+    use super::{ConnectionManager, PolicyWatcher};
 
     #[tokio::test]
     async fn test_connection_manager_close() {
         // setup a new ConnectionManager
-        let connection_manager = ConnectionManager::new();
+        let connection_manager = ConnectionManager::default();
         // ensure drains is empty
         assert_eq!(connection_manager.drains.read().await.len(), 0);
         assert_eq!(connection_manager.connections().await.len(), 0);
@@ -252,7 +267,7 @@ mod tests {
     #[tokio::test]
     async fn test_connection_manager_release() {
         // setup a new ConnectionManager
-        let connection_manager = ConnectionManager::new();
+        let connection_manager = ConnectionManager::default();
         // ensure drains is empty
         assert_eq!(connection_manager.drains.read().await.len(), 0);
         assert_eq!(connection_manager.connections().await.len(), 0);
@@ -355,21 +370,17 @@ mod tests {
             ResolverConfig::default(),
             ResolverOpts::default(),
         );
-        let connection_manager = ConnectionManager::new();
-        let parent_proxy = "test";
-        let (stop_tx, stop_rx) = watch::channel(());
+        let connection_manager = ConnectionManager::default();
+        let (tx, stop) = drain::channel();
         let state_mutator = ProxyStateUpdateMutator::new_no_fetch();
 
         // clones to move into spawned task
         let ds = dstate.clone();
         let cm = connection_manager.clone();
+        let pw = PolicyWatcher::new(ds, stop, cm);
         // spawn a task which watches policy and asserts that the policy watcher stop correctly
         tokio::spawn(async move {
-            let res = tokio::time::timeout(
-                Duration::from_secs(1),
-                super::policy_watcher(ds, stop_rx, cm, parent_proxy),
-            )
-            .await;
+            let res = tokio::time::timeout(Duration::from_secs(1), pw.run()).await;
             assert!(res.is_ok())
         });
 
@@ -399,18 +410,19 @@ mod tests {
         // spawn an assertion that our connection close is received
         tokio::spawn(assert_close(close1));
 
-        // update our state
-        let mut s = state
-            .write()
-            .expect("test fails if we're unable to get a write lock on state");
-        let res = state_mutator.insert_authorization(&mut s, auth);
-        // assert that the update was OK
-        assert!(res.is_ok());
-        // release lock
-        drop(s);
+        // this block will scope our guard appropriately
+        {
+            // update our state
+            let mut s = state
+                .write()
+                .expect("test fails if we're unable to get a write lock on state");
+            let res = state_mutator.insert_authorization(&mut s, auth);
+            // assert that the update was OK
+            assert!(res.is_ok());
+        } // release lock
 
         // send the signal which stops policy watcher
-        stop_tx.send_replace(());
+        tx.drain().await;
     }
 
     // small helper to assert that the Watches are working in a timely manner
diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs
index a8a3acf49..2b051e8e8 100644
--- a/src/proxy/inbound.rs
+++ b/src/proxy/inbound.rs
@@ -26,10 +26,9 @@ use hyper::body::Incoming;
 use hyper::service::service_fn;
 use hyper::{Method, Request, Response, StatusCode};
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::watch;
 use tracing::{debug, error, info, instrument, trace, trace_span, warn, Instrument};
 
-use super::connection_manager::{self, ConnectionManager};
+use super::connection_manager::ConnectionManager;
 use super::{Error, SocketFactory};
 use crate::baggage::parse_baggage_header;
 use crate::config::Config;
@@ -80,7 +79,7 @@ impl Inbound {
             metrics: pi.metrics,
             drain,
             socket_factory: pi.socket_factory.clone(),
-            connection_manager: ConnectionManager::new(),
+            connection_manager: pi.connection_manager,
         })
     }
 
@@ -99,17 +98,6 @@ impl Inbound {
         let mut stream = stream.take_until(Box::pin(self.drain.signaled()));
 
         let (sub_drain_signal, sub_drain) = drain::channel();
-        // spawn a task which subscribes to watch updates and asserts rbac against this proxy's connections, closing the ones which have become denied
-        let (stop_tx, stop_rx) = watch::channel(());
-        let state = self.state.clone();
-        let connection_manager = self.connection_manager.clone();
-
-        tokio::spawn(connection_manager::policy_watcher(
-            state,
-            stop_rx,
-            connection_manager,
-            "inbound",
-        ));
 
         while let Some(socket) = stream.next().await {
             let state = self.state.clone();
@@ -163,7 +151,6 @@ impl Inbound {
             });
         }
         info!("draining connections");
-        stop_tx.send_replace(()); // close the task handling auth updates
         drop(sub_drain); // sub_drain_signal.drain() will never resolve while sub_drain is valid, will deadlock if not dropped
         sub_drain_signal.drain().await;
         info!("all inbound connections drained");
diff --git a/src/proxy/inbound_passthrough.rs b/src/proxy/inbound_passthrough.rs
index 5489a08e2..dba6ea2a0 100644
--- a/src/proxy/inbound_passthrough.rs
+++ b/src/proxy/inbound_passthrough.rs
@@ -16,11 +16,10 @@ use std::net::SocketAddr;
 
 use drain::Watch;
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::watch;
 use tracing::{error, info, trace, warn, Instrument};
 
 use crate::config::ProxyMode;
-use crate::proxy::connection_manager::{self, ConnectionManager};
+use crate::proxy::connection_manager::ConnectionManager;
 use crate::proxy::metrics::Reporter;
 use crate::proxy::outbound::OutboundConnection;
 use crate::proxy::{metrics, util, ProxyInputs};
@@ -33,7 +32,6 @@ pub(super) struct InboundPassthrough {
     listener: TcpListener,
     pi: ProxyInputs,
     drain: Watch,
-    connection_manager: ConnectionManager,
 }
 
 impl InboundPassthrough {
@@ -60,29 +58,17 @@ impl InboundPassthrough {
             listener,
             pi,
             drain,
-            connection_manager: ConnectionManager::new(),
         })
     }
 
     pub(super) async fn run(self) {
-        // spawn a task which subscribes to watch updates and asserts rbac against this proxy's connections, closing the ones which have become denied
-        let (stop_tx, stop_rx) = watch::channel(());
-        let connection_manager = self.connection_manager.clone();
-        let state = self.pi.state.clone();
-
-        tokio::spawn(connection_manager::policy_watcher(
-            state,
-            stop_rx,
-            connection_manager,
-            "inbound_passthrough",
-        ));
 
         let accept = async move {
             loop {
                 // Asynchronously wait for an inbound socket.
                 let socket = self.listener.accept().await;
                 let pi = self.pi.clone();
-                let connection_manager = self.connection_manager.clone();
+                let connection_manager = self.pi.connection_manager.clone();
                 match socket {
                     Ok((stream, remote)) => {
                         tokio::spawn(async move {
@@ -114,7 +100,6 @@ impl InboundPassthrough {
             res = accept => { res }
             _ = self.drain.signaled() => {
                 info!("inbound passthrough drained");
-                stop_tx.send_replace(());
             }
         }
     }
@@ -149,7 +134,6 @@ impl InboundPassthrough {
         let mut oc = OutboundConnection {
             pi: pi.clone(),
            id: TraceParent::new(),
-            connection_manager,
        };
         // Spoofing the source IP only works when the destination or the source are on our node.
         // In this case, the source and the destination might both be remote, so we need to disable it.
diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs
index c50af788e..44cb2468e 100644
--- a/src/proxy/outbound.rs
+++ b/src/proxy/outbound.rs
@@ -28,7 +28,6 @@ use tracing::{debug, error, info, info_span, trace, trace_span, warn, Instrument
 
 use crate::config::ProxyMode;
 use crate::identity::Identity;
-use crate::proxy::connection_manager::ConnectionManager;
 use crate::proxy::inbound::{Inbound, InboundConnect};
 use crate::proxy::metrics::Reporter;
 use crate::proxy::{metrics, pool};
@@ -43,7 +42,6 @@ pub struct Outbound {
     pi: ProxyInputs,
     drain: Watch,
     listener: TcpListener,
-    connection_manager: ConnectionManager,
 }
 
 impl Outbound {
@@ -66,7 +64,6 @@ impl Outbound {
             pi,
             listener,
             drain,
-            connection_manager: ConnectionManager::new(),
         })
     }
 
@@ -85,7 +82,6 @@ impl Outbound {
                 let mut oc = OutboundConnection {
                     pi: self.pi.clone(),
                     id: TraceParent::new(),
-                    connection_manager: self.connection_manager.clone(),
                 };
                 let span = info_span!("outbound", id=%oc.id);
                 tokio::spawn(
@@ -124,7 +120,6 @@ impl Outbound {
 pub(super) struct OutboundConnection {
     pub(super) pi: ProxyInputs,
     pub(super) id: TraceParent,
-    pub(super) connection_manager: ConnectionManager,
 }
 
 impl OutboundConnection {
@@ -193,9 +188,9 @@ impl OutboundConnection {
             dst_network: req.source.network.clone(), // since this is node local, it's the same network
             dst: req.destination,
         };
-        self.connection_manager.register(&conn).await;
+        self.pi.connection_manager.register(&conn).await;
         if !self.pi.state.assert_rbac(&conn).await {
-            self.connection_manager.release(&conn).await;
+            self.pi.connection_manager.release(&conn).await;
             info!(%conn, "RBAC rejected");
             return Err(Error::HttpStatus(StatusCode::UNAUTHORIZED));
         }
@@ -220,7 +215,7 @@ impl OutboundConnection {
             connection_metrics,
             Some(inbound_connection_metrics),
             self.pi.socket_factory.as_ref(),
-            self.connection_manager.clone(),
+            self.pi.connection_manager.clone(),
             conn,
         )
         .await
@@ -625,9 +620,9 @@ mod tests {
                 metrics: test_proxy_metrics(),
                 pool: pool::Pool::new(),
                 socket_factory: std::sync::Arc::new(crate::proxy::DefaultSocketFactory),
+                connection_manager: ConnectionManager::default(),
             },
             id: TraceParent::new(),
-            connection_manager: ConnectionManager::new(),
         };
 
         let req = outbound
diff --git a/src/proxy/socks5.rs b/src/proxy/socks5.rs
index 0f9eac442..8c32e2bb1 100644
--- a/src/proxy/socks5.rs
+++ b/src/proxy/socks5.rs
@@ -22,7 +22,6 @@ use tokio::io::AsyncWriteExt;
 use tokio::net::{TcpListener, TcpStream};
 use tracing::{error, info, warn};
 
-use crate::proxy::connection_manager::ConnectionManager;
 use crate::proxy::outbound::OutboundConnection;
 use crate::proxy::{util, Error, ProxyInputs, TraceParent};
 use crate::socket;
@@ -31,7 +30,6 @@ pub(super) struct Socks5 {
     pi: ProxyInputs,
     listener: TcpListener,
     drain: Watch,
-    connection_manager: ConnectionManager,
 }
 
 impl Socks5 {
@@ -51,7 +49,6 @@ impl Socks5 {
             pi,
             listener,
             drain,
-            connection_manager: ConnectionManager::new(),
         })
     }
 
@@ -63,7 +60,6 @@ impl Socks5 {
         let accept = async move {
             loop {
                 // Asynchronously wait for an inbound socket.
-                let connection_manager = self.connection_manager.clone();
                 let socket = self.listener.accept().await;
                 match socket {
                     Ok((stream, remote)) => {
@@ -71,7 +67,6 @@ impl Socks5 {
                         let oc = OutboundConnection {
                             pi: self.pi.clone(),
                             id: TraceParent::new(),
-                            connection_manager,
                         };
                         tokio::spawn(async move {
                             if let Err(err) = handle(oc, stream).await {
diff --git a/src/proxyfactory.rs b/src/proxyfactory.rs
index 8ba95c712..7ba255bf3 100644
--- a/src/proxyfactory.rs
+++ b/src/proxyfactory.rs
@@ -21,6 +21,7 @@ use tracing::error;
 
 use crate::dns;
+use crate::proxy::connection_manager::ConnectionManager;
 use crate::proxy::{Error, Metrics};
 use crate::proxy::Proxy;
 
@@ -92,6 +93,7 @@ impl ProxyFactory {
             let pi = crate::proxy::ProxyInputs::new(
                 self.config.clone(),
                 self.cert_manager.clone(),
+                ConnectionManager::default(),
                 self.state.clone(),
                 self.proxy_metrics.clone().unwrap(),
                 socket_factory.clone(),