diff --git a/Makefile.core.mk b/Makefile.core.mk index 0fd5c0ef4..a2d86ed9a 100644 --- a/Makefile.core.mk +++ b/Makefile.core.mk @@ -11,9 +11,6 @@ endif test: RUST_BACKTRACE=1 cargo test --benches --tests --bins $(FEATURES) -test.root: - CARGO_TARGET=`rustc -vV | sed -n 's|host: ||p' | tr [:lower:] [:upper:]| tr - _`_RUNNER='sudo -E' RUST_BACKTRACE=1 cargo test --benches --tests --bins $(FEATURES) - build: cargo build $(FEATURES) diff --git a/src/proxy.rs b/src/proxy.rs index 4215c430e..5dd9fc147 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -530,13 +530,17 @@ pub fn guess_inbound_service( .map(ServiceDescription::from) } -// Checks if the connection's source identity is the identity for the upstream's waypoint +// Checks that the source identiy and address match the upstream's waypoint async fn check_from_waypoint( state: DemandProxyState, upstream: &Workload, src_identity: Option<&Identity>, + src_ip: &IpAddr, ) -> bool { - check_gateway_address(state, src_identity, upstream.waypoint.as_ref()).await + let is_waypoint = |wl: &Workload| { + Some(wl.identity()).as_ref() == src_identity && wl.workload_ips.contains(src_ip) + }; + check_gateway_address(state, upstream.waypoint.as_ref(), is_waypoint).await } // Checks if the connection's source identity is the identity for the upstream's network @@ -546,42 +550,39 @@ async fn check_from_network_gateway( upstream: &Workload, src_identity: Option<&Identity>, ) -> bool { - check_gateway_address(state, src_identity, upstream.network_gateway.as_ref()).await + let is_gateway = |wl: &Workload| Some(wl.identity()).as_ref() == src_identity; + check_gateway_address(state, upstream.network_gateway.as_ref(), is_gateway).await } // Check if the source's identity matches any workloads that make up the given gateway // TODO: This can be made more accurate by also checking addresses. -async fn check_gateway_address( +async fn check_gateway_address( state: DemandProxyState, - src_identity: Option<&Identity>, gateway_address: Option<&GatewayAddress>, -) -> bool { - let Some(src_identity) = src_identity else { + predicate: F, +) -> bool +where + F: Fn(&Workload) -> bool, +{ + let Some(gateway_address) = gateway_address else { return false; }; - if let Some(gateway_address) = gateway_address { - let from_gateway = match state.fetch_destination(&gateway_address.destination).await { - Some(Address::Workload(wl)) => &wl.identity() == src_identity, - Some(Address::Service(svc)) => { - for (_ep_uid, ep) in svc.endpoints.iter() { - // fetch workloads by workload UID since we may not have an IP for an endpoint (e.g., endpoint is just a hostname) - if state - .fetch_workload_by_uid(&ep.workload_uid) - .await - .map(|w| w.identity()) - .as_ref() - == Some(src_identity) - { - return true; - } + + match state.fetch_destination(&gateway_address.destination).await { + Some(Address::Workload(wl)) => return predicate(&wl), + Some(Address::Service(svc)) => { + for (_ep_uid, ep) in svc.endpoints.iter() { + // fetch workloads by workload UID since we may not have an IP for an endpoint (e.g., endpoint is just a hostname) + let wl = state.fetch_workload_by_uid(&ep.workload_uid).await; + if wl.as_ref().is_some_and(&predicate) { + return true; } - false } - None => false, - }; - return from_gateway; - } - false // this occurs if gateway_address was None + } + None => {} + }; + + false } #[cfg(test)] diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs index 15ee4e13c..d12e584c2 100644 --- a/src/proxy/inbound.rs +++ b/src/proxy/inbound.rs @@ -312,6 +312,7 @@ impl Inbound { pi.state.clone(), &upstream, conn.src_identity.as_ref(), + &conn.src_ip, ) .await; let from_gateway = proxy::check_from_network_gateway( @@ -504,11 +505,16 @@ impl Inbound { Some(match target_waypoint { Address::Service(svc) => { if !svc.contains_endpoint(&conn_wl, Some(connection_dst)) { + // target points to a different waypoint return Some(None); } Some((conn_wl, vec![*svc])) } Address::Workload(wl) => { + if !wl.workload_ips.contains(&conn.dst.ip()) { + // target points to a different waypoint + return Some(None); + } let svc = state.services.get_by_workload(&wl); Some((*wl, svc)) } diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs index c60865a0b..a7e70a7af 100644 --- a/src/proxy/outbound.rs +++ b/src/proxy/outbound.rs @@ -470,12 +470,11 @@ impl OutboundConnection { ) .await?; - // TODO src_id may not be enough; should also check addresses/uid - let src_id = Some(source_workload.identity()); let from_waypoint = proxy::check_from_waypoint( self.pi.state.clone(), &mutable_us.workload, - src_id.as_ref(), + Some(&source_workload.identity()), + &downstream_network_addr.address, ) .await;