
initial improvements for connection_manager (#804)
* initial improvements for connection_manager

Signed-off-by: Ian Rudie <[email protected]>

* fix odd typo

Signed-off-by: Ian Rudie <[email protected]>

* revise policy_watcher_lifecycle test for new watcher

Signed-off-by: Ian Rudie <[email protected]>

* remove an unnecessary clone

Signed-off-by: Ian Rudie <[email protected]>

* move connection manager to proxy inputs; clippy-suggested refactor

Signed-off-by: ilrudie <[email protected]>

* cleanup

Signed-off-by: ilrudie <[email protected]>

---------

Signed-off-by: Ian Rudie <[email protected]>
Signed-off-by: ilrudie <[email protected]>
ilrudie authored Feb 23, 2024
1 parent bcdce3f commit 03d4dae
Showing 7 changed files with 73 additions and 89 deletions.
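
Taken together, the change replaces the per-listener policy_watcher tasks with a single ConnectionManager shared through ProxyInputs and a single PolicyWatcher spawned by the Proxy. The sketch below is a simplified, self-contained model of that pattern rather than the ztunnel code: SharedConnections, PolicyWatcherModel, and is_allowed are stand-ins for ConnectionManager, PolicyWatcher, and DemandProxyState::assert_rbac, the policy feed is a plain tokio::sync::watch channel, and tokio and the drain crate are assumed as dependencies.

use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, RwLock};

// Stand-in for ConnectionManager: one Arc-backed registry, cloned freely.
#[derive(Clone, Default)]
struct SharedConnections {
    conns: Arc<RwLock<HashSet<u32>>>,
}

impl SharedConnections {
    async fn track(&self, id: u32) {
        self.conns.write().await.insert(id);
    }
    async fn connections(&self) -> Vec<u32> {
        self.conns.read().await.iter().copied().collect()
    }
    async fn close(&self, id: u32) {
        self.conns.write().await.remove(&id);
        println!("connection {id} closed because it's no longer allowed");
    }
}

// Stand-in for DemandProxyState::assert_rbac: deny odd connection ids.
async fn is_allowed(id: u32) -> bool {
    id % 2 == 0
}

// Stand-in for PolicyWatcher: stops on drain, re-checks connections on policy change.
struct PolicyWatcherModel {
    stop: drain::Watch,
    policies_changed: watch::Receiver<()>,
    connections: SharedConnections,
}

impl PolicyWatcherModel {
    async fn run(mut self) {
        loop {
            tokio::select! {
                _ = self.stop.clone().signaled() => break,
                _ = self.policies_changed.changed() => {
                    for id in self.connections.connections().await {
                        if !is_allowed(id).await {
                            self.connections.close(id).await;
                        }
                    }
                }
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let connections = SharedConnections::default();
    let (policy_tx, policy_rx) = watch::channel(());
    let (shutdown, stop) = drain::channel();

    connections.track(1).await;
    connections.track(2).await;

    // One watcher for the whole proxy, mirroring what Proxy::run now spawns.
    let watcher = PolicyWatcherModel {
        stop,
        policies_changed: policy_rx,
        connections: connections.clone(),
    };
    let handle = tokio::spawn(watcher.run());

    policy_tx.send_replace(()); // simulate a policy update; connection 1 gets closed
    tokio::time::sleep(Duration::from_millis(50)).await;
    shutdown.drain().await; // stop the watcher via the drain signal, as the proxy shutdown path does
    handle.await.unwrap();
}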
11 changes: 10 additions & 1 deletion src/proxy.rs
@@ -32,6 +32,7 @@ pub use metrics::*;

use crate::identity::SecretManager;
use crate::metrics::Recorder;
use crate::proxy::connection_manager::{ConnectionManager, PolicyWatcher};
use crate::proxy::inbound_passthrough::InboundPassthrough;
use crate::proxy::outbound::Outbound;
use crate::proxy::socks5::Socks5;
@@ -41,7 +42,7 @@ use crate::state::workload::{network_addr, Workload};
use crate::state::DemandProxyState;
use crate::{config, identity, socket, tls};

mod connection_manager;
pub mod connection_manager;
mod inbound;
mod inbound_passthrough;
#[allow(non_camel_case_types)]
@@ -91,12 +92,14 @@ pub struct Proxy {
inbound_passthrough: InboundPassthrough,
outbound: Outbound,
socks5: Socks5,
policy_watcher: PolicyWatcher,
}

#[derive(Clone)]
pub(super) struct ProxyInputs {
cfg: config::Config,
cert_manager: Arc<SecretManager>,
connection_manager: ConnectionManager,
hbone_port: u16,
pub state: DemandProxyState,
metrics: Arc<Metrics>,
@@ -108,6 +111,7 @@ impl ProxyInputs {
pub fn new(
cfg: config::Config,
cert_manager: Arc<SecretManager>,
connection_manager: ConnectionManager,
state: DemandProxyState,
metrics: Arc<Metrics>,
socket_factory: Arc<dyn SocketFactory + Send + Sync>,
@@ -117,6 +121,7 @@ impl ProxyInputs {
state,
cert_manager,
metrics,
connection_manager,
pool: pool::Pool::new(),
hbone_port: 0,
socket_factory,
@@ -137,6 +142,7 @@ impl Proxy {
cfg,
state,
cert_manager,
connection_manager: ConnectionManager::default(),
metrics,
pool: pool::Pool::new(),
hbone_port: 0,
@@ -152,12 +158,14 @@ impl Proxy {
let inbound_passthrough = InboundPassthrough::new(pi.clone(), drain.clone()).await?;
let outbound = Outbound::new(pi.clone(), drain.clone()).await?;
let socks5 = Socks5::new(pi.clone(), drain.clone()).await?;
let policy_watcher = PolicyWatcher::new(pi.state, drain, pi.connection_manager);

Ok(Proxy {
inbound,
inbound_passthrough,
outbound,
socks5,
policy_watcher,
})
}

@@ -167,6 +175,7 @@ impl Proxy {
tokio::spawn(self.inbound.run().in_current_span()),
tokio::spawn(self.outbound.run().in_current_span()),
tokio::spawn(self.socks5.run().in_current_span()),
tokio::spawn(self.policy_watcher.run().in_current_span()),
];

futures::future::join_all(tasks).await;
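
Note on the proxy.rs wiring above: ProxyInputs derives Clone and is cloned once per listener, yet every listener ends up with the same connection registry, because ConnectionManager's only field is an Arc-wrapped map (the drains field in connection_manager.rs below), so cloning the manager clones a handle rather than the data. A small stand-in sketch of that behaviour, with a simplified Manager type in place of ConnectionManager:

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

// Simplified stand-in for ConnectionManager: the only field is an Arc.
#[derive(Clone, Default)]
struct Manager {
    drains: Arc<RwLock<HashMap<String, u32>>>,
}

#[tokio::main]
async fn main() {
    let a = Manager::default();
    let b = a.clone(); // roughly what each pi.clone() hands to a listener
    a.drains.write().await.insert("conn-1".into(), 1);
    assert_eq!(b.drains.read().await.len(), 1); // both handles see the same entry
}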
94 changes: 53 additions & 41 deletions src/proxy/connection_manager.rs
@@ -18,7 +18,6 @@ use crate::state::DemandProxyState;
use drain;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::watch;
use tokio::sync::RwLock;
use tracing::info;

@@ -51,13 +50,15 @@ pub struct ConnectionManager {
drains: Arc<RwLock<HashMap<Connection, ConnectionDrain>>>,
}

impl ConnectionManager {
pub fn new() -> Self {
impl std::default::Default for ConnectionManager {
fn default() -> Self {
ConnectionManager {
drains: Arc::new(RwLock::new(HashMap::new())),
}
}
}

impl ConnectionManager {
// register a connection with the connection manager
// this must be done before a connection can be tracked
// allows policy to be asserted against the connection
@@ -119,24 +120,39 @@ impl ConnectionManager {
}
}

pub async fn policy_watcher(
pub struct PolicyWatcher {
state: DemandProxyState,
mut stop_rx: watch::Receiver<()>,
stop: drain::Watch,
connection_manager: ConnectionManager,
parent_proxy: &str,
) {
let mut policies_changed = state.read().policies.subscribe();
loop {
tokio::select! {
_ = stop_rx.changed() => {
break;
}
_ = policies_changed.changed() => {
let connections = connection_manager.connections().await;
for conn in connections {
if !state.assert_rbac(&conn).await {
connection_manager.close(&conn).await;
info!("{parent_proxy} connection {conn} closed because it's no longer allowed after a policy update");
}

impl PolicyWatcher {
pub fn new(
state: DemandProxyState,
stop: drain::Watch,
connection_manager: ConnectionManager,
) -> Self {
PolicyWatcher {
state,
stop,
connection_manager,
}
}

pub async fn run(self) {
let mut policies_changed = self.state.read().policies.subscribe();
loop {
tokio::select! {
_ = self.stop.clone().signaled() => {
break;
}
_ = policies_changed.changed() => {
let connections = self.connection_manager.connections().await;
for conn in connections {
if !self.state.assert_rbac(&conn).await {
self.connection_manager.close(&conn).await;
info!("connection {conn} closed because it's no longer allowed after a policy update");
}
}
}
}
@@ -151,19 +167,18 @@ mod tests {
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::watch;

use crate::rbac::Connection;
use crate::state::{DemandProxyState, ProxyState};
use crate::xds::istio::security::{Action, Authorization, Scope};
use crate::xds::ProxyStateUpdateMutator;

use super::ConnectionManager;
use super::{ConnectionManager, PolicyWatcher};

#[tokio::test]
async fn test_connection_manager_close() {
// setup a new ConnectionManager
let connection_manager = ConnectionManager::new();
let connection_manager = ConnectionManager::default();
// ensure drains is empty
assert_eq!(connection_manager.drains.read().await.len(), 0);
assert_eq!(connection_manager.connections().await.len(), 0);
@@ -252,7 +267,7 @@ mod tests {
#[tokio::test]
async fn test_connection_manager_release() {
// setup a new ConnectionManager
let connection_manager = ConnectionManager::new();
let connection_manager = ConnectionManager::default();
// ensure drains is empty
assert_eq!(connection_manager.drains.read().await.len(), 0);
assert_eq!(connection_manager.connections().await.len(), 0);
@@ -355,21 +370,17 @@ mod tests {
ResolverConfig::default(),
ResolverOpts::default(),
);
let connection_manager = ConnectionManager::new();
let parent_proxy = "test";
let (stop_tx, stop_rx) = watch::channel(());
let connection_manager = ConnectionManager::default();
let (tx, stop) = drain::channel();
let state_mutator = ProxyStateUpdateMutator::new_no_fetch();

// clones to move into spawned task
let ds = dstate.clone();
let cm = connection_manager.clone();
let pw = PolicyWatcher::new(ds, stop, cm);
// spawn a task which watches policy and asserts that the policy watcher stop correctly
tokio::spawn(async move {
let res = tokio::time::timeout(
Duration::from_secs(1),
super::policy_watcher(ds, stop_rx, cm, parent_proxy),
)
.await;
let res = tokio::time::timeout(Duration::from_secs(1), pw.run()).await;
assert!(res.is_ok())
});

@@ -399,18 +410,19 @@ mod tests {
// spawn an assertion that our connection close is received
tokio::spawn(assert_close(close1));

// update our state
let mut s = state
.write()
.expect("test fails if we're unable to get a write lock on state");
let res = state_mutator.insert_authorization(&mut s, auth);
// assert that the update was OK
assert!(res.is_ok());
// release lock
drop(s);
// this block will scope our guard appropriately
{
// update our state
let mut s = state
.write()
.expect("test fails if we're unable to get a write lock on state");
let res = state_mutator.insert_authorization(&mut s, auth);
// assert that the update was OK
assert!(res.is_ok());
} // release lock

// send the signal which stops policy watcher
stop_tx.send_replace(());
tx.drain().await;
}

// small helper to assert that the Watches are working in a timely manner
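
The other lifecycle change in connection_manager.rs is that the watcher's stop signal moves from a tokio::sync::watch channel to the drain crate, which is what the revised test drives with tx.drain().await. A minimal sketch of that handshake, assuming drain 0.1 semantics, where drain() resolves only once every Watch clone (and the ReleaseShutdown it yields) has been dropped, so the caller knows the watcher has actually exited:

use std::time::Duration;

#[tokio::main]
async fn main() {
    let (signal, watch) = drain::channel();

    let task = tokio::spawn(async move {
        // signaled() consumes the Watch, which is why run() calls self.stop.clone().signaled()
        let _release = watch.signaled().await;
        println!("shutdown observed, finishing up");
        tokio::time::sleep(Duration::from_millis(50)).await;
        // _release is dropped here, letting drain() below resolve
    });

    signal.drain().await; // as tx.drain().await does in the revised lifecycle test
    println!("watcher released");
    task.await.unwrap();
}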
17 changes: 2 additions & 15 deletions src/proxy/inbound.rs
@@ -26,10 +26,9 @@ use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;
use tracing::{debug, error, info, instrument, trace, trace_span, warn, Instrument};

use super::connection_manager::{self, ConnectionManager};
use super::connection_manager::ConnectionManager;
use super::{Error, SocketFactory};
use crate::baggage::parse_baggage_header;
use crate::config::Config;
@@ -80,7 +79,7 @@ impl Inbound {
metrics: pi.metrics,
drain,
socket_factory: pi.socket_factory.clone(),
connection_manager: ConnectionManager::new(),
connection_manager: pi.connection_manager,
})
}

@@ -99,17 +98,6 @@ impl Inbound {
let mut stream = stream.take_until(Box::pin(self.drain.signaled()));

let (sub_drain_signal, sub_drain) = drain::channel();
// spawn a task which subscribes to watch updates and asserts rbac against this proxy's connections, closing the ones which have become denied
let (stop_tx, stop_rx) = watch::channel(());
let state = self.state.clone();
let connection_manager = self.connection_manager.clone();

tokio::spawn(connection_manager::policy_watcher(
state,
stop_rx,
connection_manager,
"inbound",
));

while let Some(socket) = stream.next().await {
let state = self.state.clone();
@@ -163,7 +151,6 @@ impl Inbound {
});
}
info!("draining connections");
stop_tx.send_replace(()); // close the task handling auth updates
drop(sub_drain); // sub_drain_signal.drain() will never resolve while sub_drain is valid, will deadlock if not dropped
sub_drain_signal.drain().await;
info!("all inbound connections drained");
20 changes: 2 additions & 18 deletions src/proxy/inbound_passthrough.rs
@@ -16,11 +16,10 @@ use std::net::SocketAddr;

use drain::Watch;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;
use tracing::{error, info, trace, warn, Instrument};

use crate::config::ProxyMode;
use crate::proxy::connection_manager::{self, ConnectionManager};
use crate::proxy::connection_manager::ConnectionManager;
use crate::proxy::metrics::Reporter;
use crate::proxy::outbound::OutboundConnection;
use crate::proxy::{metrics, util, ProxyInputs};
@@ -33,7 +32,6 @@ pub(super) struct InboundPassthrough {
listener: TcpListener,
pi: ProxyInputs,
drain: Watch,
connection_manager: ConnectionManager,
}

impl InboundPassthrough {
@@ -60,29 +58,17 @@ impl InboundPassthrough {
listener,
pi,
drain,
connection_manager: ConnectionManager::new(),
})
}

pub(super) async fn run(self) {
// spawn a task which subscribes to watch updates and asserts rbac against this proxy's connections, closing the ones which have become denied
let (stop_tx, stop_rx) = watch::channel(());
let connection_manager = self.connection_manager.clone();
let state = self.pi.state.clone();

tokio::spawn(connection_manager::policy_watcher(
state,
stop_rx,
connection_manager,
"inbound_passthrough",
));
let accept = async move {
loop {
// Asynchronously wait for an inbound socket.
let socket = self.listener.accept().await;
let pi = self.pi.clone();

let connection_manager = self.connection_manager.clone();
let connection_manager = self.pi.connection_manager.clone();
match socket {
Ok((stream, remote)) => {
tokio::spawn(async move {
Expand Down Expand Up @@ -114,7 +100,6 @@ impl InboundPassthrough {
res = accept => { res }
_ = self.drain.signaled() => {
info!("inbound passthrough drained");
stop_tx.send_replace(());
}
}
}
Expand Down Expand Up @@ -149,7 +134,6 @@ impl InboundPassthrough {
let mut oc = OutboundConnection {
pi: pi.clone(),
id: TraceParent::new(),
connection_manager,
};
// Spoofing the source IP only works when the destination or the source are on our node.
// In this case, the source and the destination might both be remote, so we need to disable it.