From deceecf7a803c9e24484252f92f5ff23cf99767f Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 16 Oct 2023 16:17:42 -0700 Subject: [PATCH 1/3] Infra stall detection --- Cargo.toml | 8 +- src/helpers/buffers/mod.rs | 34 -- src/helpers/buffers/ordering_sender.rs | 23 ++ src/helpers/buffers/unordered_receiver.rs | 71 ++++- src/helpers/gateway/mod.rs | 79 +++-- src/helpers/gateway/receive.rs | 33 +- src/helpers/gateway/send.rs | 12 +- src/helpers/gateway/stall_detection.rs | 372 ++++++++++++++++++++++ src/helpers/gateway/transport.rs | 12 +- src/helpers/mod.rs | 34 +- src/helpers/prss_protocol.rs | 6 +- src/lib.rs | 20 +- src/protocol/step/compact.rs | 2 +- src/protocol/step/descriptive.rs | 2 +- 14 files changed, 602 insertions(+), 106 deletions(-) create mode 100644 src/helpers/gateway/stall_detection.rs diff --git a/Cargo.toml b/Cargo.toml index 1767e8e24..53bd0ceb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,8 @@ default = [ "tracing/max_level_trace", "tracing/release_max_level_info", "descriptive-gate", - "aggregate-circuit" + "aggregate-circuit", + "stall-detection", ] cli = ["comfy-table", "clap"] enable-serde = ["serde", "serde_json"] @@ -21,6 +22,10 @@ disable-metrics = [] # TODO Consider moving out benches as well web-app = ["axum", "axum-server", "base64", "clap", "comfy-table", "enable-serde", "hyper", "hyper-rustls", "rcgen", "rustls", "rustls-pemfile", "time", "tokio-rustls", "toml", "tower", "tower-http"] test-fixture = ["enable-serde", "weak-field"] +# Include observability instruments that detect lack of progress inside MPC. If there is a bug that leads to helper +# miscommunication, this feature helps to detect it. Turning it on has some cost. +# If "shuttle" feature is enabled, turning this on has no effect. +stall-detection = [] shuttle = ["shuttle-crate", "test-fixture"] debug-trace = ["tracing/max_level_trace", "tracing/release_max_level_debug"] # TODO: we may want to use in-memory-bench and real-world-bench some time after @@ -55,6 +60,7 @@ comfy-table = { version = "7.0", optional = true } config = "0.13.2" criterion = { version = "0.5.1", optional = true, default-features = false, features = ["async_tokio", "plotters", "html_reports"] } dashmap = "5.4" +delegate = "0.10.0" dhat = "0.3.2" embed-doc-image = "0.1.4" futures = "0.3.28" diff --git a/src/helpers/buffers/mod.rs b/src/helpers/buffers/mod.rs index 83a92c24e..943c884dd 100644 --- a/src/helpers/buffers/mod.rs +++ b/src/helpers/buffers/mod.rs @@ -5,37 +5,3 @@ mod unordered_receiver; pub use ordering_mpsc::{ordering_mpsc, OrderingMpscReceiver, OrderingMpscSender}; pub use ordering_sender::{OrderedStream, OrderingSender}; pub use unordered_receiver::UnorderedReceiver; - -#[cfg(debug_assertions)] -#[allow(unused)] // todo(alex): make test world print the state again -mod waiting { - use std::collections::HashMap; - - use crate::helpers::ChannelId; - - pub(in crate::helpers) struct WaitingTasks<'a> { - tasks: HashMap<&'a ChannelId, Vec>, - } - - impl<'a> WaitingTasks<'a> { - pub fn new(tasks: HashMap<&'a ChannelId, Vec>) -> Self { - Self { tasks } - } - - pub fn is_empty(&self) -> bool { - self.tasks.is_empty() - } - } - - impl std::fmt::Debug for WaitingTasks<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[")?; - for (channel, records) in &self.tasks { - write!(f, "\n {channel:?}: {records:?}")?; - } - write!(f, "\n]")?; - - Ok(()) - } - } -} diff --git a/src/helpers/buffers/ordering_sender.rs b/src/helpers/buffers/ordering_sender.rs index ea1459d62..5ecc21c7f 100644 --- a/src/helpers/buffers/ordering_sender.rs +++ b/src/helpers/buffers/ordering_sender.rs @@ -188,6 +188,11 @@ impl WaitingShard { self.wakers.pop_front().unwrap().w.wake(); } } + + #[cfg(feature = "stall-detection")] + pub fn waiting(&self) -> impl Iterator + '_ { + self.wakers.iter().map(|waker| waker.i) + } } /// A collection of wakers that are indexed by the send index (`i`). @@ -224,6 +229,19 @@ impl Waiting { fn wake(&self, i: usize) { self.shard(i).wake(i); } + + /// Returns all records currently waiting to be sent in sorted order. + #[cfg(feature = "stall-detection")] + fn waiting(&self) -> Vec { + let mut records = Vec::new(); + self.shards + .iter() + .for_each(|shard| records.extend(shard.lock().unwrap().waiting())); + + records.sort_unstable(); + + records + } } /// An `OrderingSender` accepts messages for sending in any order, but @@ -375,6 +393,11 @@ impl OrderingSender { ) -> OrderedStream> { OrderedStream { sender: self } } + + #[cfg(feature = "stall-detection")] + pub fn waiting(&self) -> Vec { + self.waiting.waiting() + } } /// A future for writing item `i` into an `OrderingSender`. diff --git a/src/helpers/buffers/unordered_receiver.rs b/src/helpers/buffers/unordered_receiver.rs index d578e6456..88de6996e 100644 --- a/src/helpers/buffers/unordered_receiver.rs +++ b/src/helpers/buffers/unordered_receiver.rs @@ -116,6 +116,8 @@ where stream: Pin>, /// The absolute index of the next value that will be received. next: usize, + /// The maximum value that has ever been requested to receive. + max_polled_idx: usize, /// The underlying stream can provide chunks of data larger than a single /// message. Save any spare data here. spare: Spare, @@ -143,6 +145,12 @@ where /// Note: in protocols we try to send before receiving, so we can rely on /// that easing load on this mechanism. There might also need to be some /// end-to-end back pressure for tasks that do not involve sending at all. + /// + /// If stall detection is enabled, the index of that waker is stored alongside with it, in order + /// to correctly identify the `i` awaiting completion + #[cfg(feature = "stall-detection")] + overflow_wakers: Vec<(Waker, usize)>, + #[cfg(not(feature = "stall-detection"))] overflow_wakers: Vec, _marker: PhantomData, } @@ -172,7 +180,11 @@ where ); // We don't save a waker at `self.next`, so `>` and not `>=`. if i > self.next + self.wakers.len() { - self.overflow_wakers.push(waker); + #[cfg(feature = "stall-detection")] + let overflow = (waker, i); + #[cfg(not(feature = "stall-detection"))] + let overflow = waker; + self.overflow_wakers.push(overflow); } else { let index = i % self.wakers.len(); if let Some(old) = self.wakers[index].as_ref() { @@ -195,6 +207,11 @@ where } if self.next % (self.wakers.len() / 2) == 0 { // Wake all the overflowed wakers. See comments on `overflow_wakers`. + #[cfg(feature = "stall-detection")] + for (w, _) in take(&mut self.overflow_wakers) { + w.wake(); + } + #[cfg(not(feature = "stall-detection"))] for w in take(&mut self.overflow_wakers) { w.wake(); } @@ -204,6 +221,7 @@ where /// Poll for the next record. This should only be invoked when /// the future for the next message is polled. fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + self.max_polled_idx = std::cmp::max(self.max_polled_idx, self.next); if let Some(m) = self.spare.read() { self.wake_next(); return Poll::Ready(Ok(m)); @@ -228,6 +246,46 @@ where } } } + + #[cfg(feature = "stall-detection")] + fn waiting(&self) -> impl Iterator + '_ { + /// There is no waker for self.next and it could be advanced past the end of the stream. + /// This helps to conditionally add self.next to the waiting list. + struct MaybeNext { + currently_at: usize, + next: usize, + } + impl Iterator for MaybeNext { + type Item = usize; + + fn next(&mut self) -> Option { + if self.currently_at == self.next { + self.currently_at += 1; + Some(self.next) + } else { + None + } + } + } + + let start = self.next % self.wakers.len(); + self.wakers + .iter() + .enumerate() + .filter_map(|(i, waker)| waker.as_ref().map(|_| i)) + .map(move |i| { + if i < start { + self.next + (self.wakers.len() - start + i) + } else { + self.next + (i - start) + } + }) + .chain(self.overflow_wakers.iter().map(|v| v.1)) + .chain(MaybeNext { + currently_at: self.max_polled_idx, + next: self.next, + }) + } } /// Take an ordered stream of bytes and make messages from that stream @@ -262,6 +320,7 @@ where inner: Arc::new(Mutex::new(OperatingState { stream, next: 0, + max_polled_idx: 0, spare: Spare::default(), wakers, overflow_wakers: Vec::new(), @@ -284,6 +343,16 @@ where _marker: PhantomData, } } + + #[cfg(feature = "stall-detection")] + pub fn waiting(&self) -> Vec { + let state = self.inner.lock().unwrap(); + let mut r = state.waiting().collect::>(); + + r.sort_unstable(); + + r + } } impl Clone for UnorderedReceiver diff --git a/src/helpers/gateway/mod.rs b/src/helpers/gateway/mod.rs index c2a37916c..901783ea6 100644 --- a/src/helpers/gateway/mod.rs +++ b/src/helpers/gateway/mod.rs @@ -1,19 +1,22 @@ mod receive; mod send; +#[cfg(feature = "stall-detection")] +pub(super) mod stall_detection; mod transport; -use std::{fmt::Debug, num::NonZeroUsize}; +use std::num::NonZeroUsize; -pub use send::SendingEnd; -#[cfg(all(feature = "shuttle", test))] +pub(super) use receive::ReceivingEnd; +pub(super) use send::SendingEnd; +#[cfg(all(test, feature = "shuttle"))] use shuttle::future as tokio; +#[cfg(feature = "stall-detection")] +pub(super) use stall_detection::InstrumentedGateway; use crate::{ helpers::{ gateway::{ - receive::{GatewayReceivers, ReceivingEnd as ReceivingEndBase}, - send::GatewaySenders, - transport::RoleResolvingTransport, + receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport, }, ChannelId, Message, Role, RoleAssignment, TotalRecords, Transport, }, @@ -31,18 +34,21 @@ pub type TransportImpl = super::transport::InMemoryTransport; pub type TransportImpl = crate::sync::Arc; pub type TransportError = ::Error; -pub type ReceivingEnd = ReceivingEndBase; -/// Gateway into IPA Infrastructure systems. This object allows sending and receiving messages. -/// As it is generic over network/transport layer implementation, type alias [`Gateway`] should be -/// used to avoid carrying `T` over. -/// -/// [`Gateway`]: crate::helpers::Gateway -pub struct Gateway { +/// Gateway into IPA Network infrastructure. It allows helpers send and receive messages. +pub struct Gateway { config: GatewayConfig, - transport: RoleResolvingTransport, + transport: RoleResolvingTransport, + #[cfg(feature = "stall-detection")] + inner: crate::sync::Arc, + #[cfg(not(feature = "stall-detection"))] + inner: State, +} + +#[derive(Default)] +pub struct State { senders: GatewaySenders, - receivers: GatewayReceivers, + receivers: GatewayReceivers, } #[derive(Clone, Copy, Debug)] @@ -50,16 +56,23 @@ pub struct GatewayConfig { /// The number of items that can be active at the one time. /// This is used to determine the size of sending and receiving buffers. active: NonZeroUsize, + + /// Time to wait before checking gateway progress. If no progress has been made between + /// checks, the gateway is considered to be stalled and will create a report with outstanding + /// send/receive requests + #[cfg(feature = "stall-detection")] + pub progress_check_interval: std::time::Duration, } -impl Gateway { +impl Gateway { #[must_use] pub fn new( query_id: QueryId, config: GatewayConfig, roles: RoleAssignment, - transport: T, + transport: TransportImpl, ) -> Self { + #[allow(clippy::useless_conversion)] Self { config, transport: RoleResolvingTransport { @@ -68,8 +81,7 @@ impl Gateway { inner: transport, config, }, - senders: GatewaySenders::default(), - receivers: GatewayReceivers::default(), + inner: State::default().into(), } } @@ -91,10 +103,12 @@ impl Gateway { &self, channel_id: &ChannelId, total_records: TotalRecords, - ) -> SendingEnd { - let (tx, maybe_stream) = - self.senders - .get_or_create::(channel_id, self.config.active_work(), total_records); + ) -> send::SendingEnd { + let (tx, maybe_stream) = self.inner.senders.get_or_create::( + channel_id, + self.config.active_work(), + total_records, + ); if let Some(stream) = maybe_stream { tokio::spawn({ let channel_id = channel_id.clone(); @@ -109,14 +123,15 @@ impl Gateway { }); } - SendingEnd::new(tx, self.role(), channel_id) + send::SendingEnd::new(tx, self.role(), channel_id) } #[must_use] - pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEndBase { - ReceivingEndBase::new( + pub fn get_receiver(&self, channel_id: &ChannelId) -> receive::ReceivingEnd { + receive::ReceivingEnd::new( channel_id.clone(), - self.receivers + self.inner + .receivers .get_or_create(channel_id, || self.transport.receive(channel_id)), ) } @@ -135,8 +150,18 @@ impl GatewayConfig { /// If `active` is 0. #[must_use] pub fn new(active: usize) -> Self { + // In-memory tests are fast, so progress check intervals can be lower. + // Real world scenarios currently over-report stalls because of inefficiencies inside + // infrastructure and actual networking issues. This checks is only valuable to report + // bugs, so keeping it large enough to avoid false positives. Self { active: NonZeroUsize::new(active).unwrap(), + #[cfg(feature = "stall-detection")] + progress_check_interval: std::time::Duration::from_secs(if cfg!(test) { + 5 + } else { + 30 + }), } } diff --git a/src/helpers/gateway/receive.rs b/src/helpers/gateway/receive.rs index 282ff68e2..c30d285ed 100644 --- a/src/helpers/gateway/receive.rs +++ b/src/helpers/gateway/receive.rs @@ -4,29 +4,30 @@ use dashmap::{mapref::entry::Entry, DashMap}; use futures::Stream; use crate::{ - helpers::{buffers::UnorderedReceiver, ChannelId, Error, Message, Transport}, + helpers::{buffers::UnorderedReceiver, ChannelId, Error, Message, Transport, TransportImpl}, protocol::RecordId, }; /// Receiving end end of the gateway channel. -pub struct ReceivingEnd { +pub struct ReceivingEnd { channel_id: ChannelId, - unordered_rx: UR, + unordered_rx: UR, _phantom: PhantomData, } /// Receiving channels, indexed by (role, step). -pub(super) struct GatewayReceivers { - inner: DashMap>, +#[derive(Default)] +pub(super) struct GatewayReceivers { + pub(super) inner: DashMap, } -pub(super) type UR = UnorderedReceiver< - ::RecordsStream, - <::RecordsStream as Stream>::Item, +pub(super) type UR = UnorderedReceiver< + ::RecordsStream, + <::RecordsStream as Stream>::Item, >; -impl ReceivingEnd { - pub(super) fn new(channel_id: ChannelId, rx: UR) -> Self { +impl ReceivingEnd { + pub(super) fn new(channel_id: ChannelId, rx: UR) -> Self { Self { channel_id, unordered_rx: rx, @@ -55,16 +56,8 @@ impl ReceivingEnd { } } -impl Default for GatewayReceivers { - fn default() -> Self { - Self { - inner: DashMap::default(), - } - } -} - -impl GatewayReceivers { - pub fn get_or_create UR>(&self, channel_id: &ChannelId, ctr: F) -> UR { +impl GatewayReceivers { + pub fn get_or_create UR>(&self, channel_id: &ChannelId, ctr: F) -> UR { // TODO: raw entry API if it becomes available to avoid cloning the key match self.inner.entry(channel_id.clone()) { Entry::Occupied(entry) => entry.get().clone(), diff --git a/src/helpers/gateway/send.rs b/src/helpers/gateway/send.rs index 4eb876af0..4f07c5151 100644 --- a/src/helpers/gateway/send.rs +++ b/src/helpers/gateway/send.rs @@ -30,7 +30,7 @@ pub struct SendingEnd { /// Sending channels, indexed by (role, step). #[derive(Default)] pub(super) struct GatewaySenders { - inner: DashMap>, + pub(super) inner: DashMap>, } pub(super) struct GatewaySender { @@ -77,6 +77,16 @@ impl GatewaySender { Ok(()) } + + #[cfg(feature = "stall-detection")] + pub fn waiting(&self) -> Vec { + self.ordering_tx.waiting() + } + + #[cfg(feature = "stall-detection")] + pub fn total_records(&self) -> TotalRecords { + self.total_records + } } impl SendingEnd { diff --git a/src/helpers/gateway/stall_detection.rs b/src/helpers/gateway/stall_detection.rs new file mode 100644 index 000000000..2803e5aed --- /dev/null +++ b/src/helpers/gateway/stall_detection.rs @@ -0,0 +1,372 @@ +use std::{ + fmt::{Debug, Display, Formatter}, + ops::{RangeInclusive, Sub}, +}; + +pub use gateway::InstrumentedGateway; + +use crate::sync::{ + atomic::{AtomicUsize, Ordering}, + Weak, +}; + +/// Trait for structs that can report their current state. +pub trait ObserveState { + type State: Debug; + fn get_state(&self) -> Option; +} + +/// This object does not own the sequence number, it must be stored outside and dropped when +/// observing entity goes out of scope. If that happens, any attempt to increment it through this +/// instance will result in a panic. +/// +/// Observing and incrementing sequence numbers do not introduce happens-before relationship. +pub struct Observed { + /// Each time a state change occurs inside the observable object `T`, its sequence number is + /// incremented by 1. It is up to the caller to decide what is a state change. + /// + /// The sequence number is stored as a weak reference, so it can be dropped when the observed + /// object is dropped. + /// + /// External observers watching this object will declare it stalled if it's sequence number + /// hasn't been incremented for long enough time. It can happen for two reasons: either there is + /// no work to do for this object, or its state is not drained/consumed by the clients. In the + /// former case, the bottleneck is somewhere else, otherwise if `T` implements `ObserveState`, + /// the current state of `T` is also reported. + sn: Weak, + inner: T, +} + +impl Observed { + fn wrap(sn: Weak, inner: T) -> Self { + Self { sn, inner } + } + + fn get_sn(&self) -> &Weak { + &self.sn + } + + /// Advances the sequence number ahead. + /// + /// ## Panics + /// This will panic if the sequence number is dropped. + fn advance(&self) { + let sn = self.sn.upgrade().unwrap(); + sn.fetch_add(1, Ordering::Relaxed); + } + + fn inner(&self) -> &T { + &self.inner + } +} + +impl Observed { + pub fn get_state(&self) -> Option { + self.inner().get_state() + } +} + +mod gateway { + use delegate::delegate; + + use super::*; + use crate::{ + helpers::{ + gateway::{Gateway, State}, + ChannelId, GatewayConfig, Message, ReceivingEnd, Role, RoleAssignment, SendingEnd, + TotalRecords, TransportImpl, + }, + protocol::QueryId, + sync::Arc, + }; + + pub struct InstrumentedGateway { + gateway: Gateway, + // Gateway owns the sequence number associated with it. When it goes out of scope, sn is destroyed + // and external observers can see that they no longer need to watch it. + _sn: Arc, + } + + impl Observed { + delegate! { + to self.inner().gateway { + + #[inline] + pub fn role(&self) -> Role; + + #[inline] + pub fn config(&self) -> &GatewayConfig; + } + } + + #[allow(clippy::let_and_return)] + pub fn new( + query_id: QueryId, + config: GatewayConfig, + roles: RoleAssignment, + transport: TransportImpl, + ) -> Self { + let version = Arc::new(AtomicUsize::default()); + let r = Self::wrap( + Arc::downgrade(&version), + InstrumentedGateway { + gateway: Gateway::new(query_id, config, roles, transport), + _sn: version, + }, + ); + + // spawn the watcher + #[cfg(not(feature = "shuttle"))] + { + use tracing::Instrument; + + tokio::spawn({ + let gateway = r.to_observed(); + async move { + let mut last_sn_seen = 0; + loop { + ::tokio::time::sleep(config.progress_check_interval).await; + let now = gateway.get_sn().upgrade().map(|v| v.load(Ordering::Relaxed)); + if let Some(now) = now { + if now == last_sn_seen { + if let Some(state) = gateway.get_state() { + tracing::warn!(sn = now, state = ?state, "Helper is stalled"); + } + } + last_sn_seen = now; + } else { + break; + } + } + }.instrument(tracing::info_span!("stall_detector", role = ?r.role())) + }); + } + + r + } + + #[must_use] + pub fn get_sender( + &self, + channel_id: &ChannelId, + total_records: TotalRecords, + ) -> SendingEnd { + Observed::wrap( + Weak::clone(self.get_sn()), + self.inner().gateway.get_sender(channel_id, total_records), + ) + } + + #[must_use] + pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEnd { + Observed::wrap( + Weak::clone(self.get_sn()), + self.inner().gateway.get_receiver(channel_id), + ) + } + + pub fn to_observed(&self) -> Observed> { + // todo: inner.inner + Observed::wrap( + Weak::clone(self.get_sn()), + Arc::downgrade(&self.inner().gateway.inner), + ) + } + } + + pub struct GatewayWaitingTasks { + senders_state: Option, + receivers_state: Option, + } + + impl Debug for GatewayWaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(senders_state) = &self.senders_state { + write!(f, "\n{{{senders_state:?}\n}}")?; + } + if let Some(receivers_state) = &self.receivers_state { + write!(f, "\n{{{receivers_state:?}\n}}")?; + } + + Ok(()) + } + } + + impl ObserveState for Weak { + type State = GatewayWaitingTasks; + + fn get_state(&self) -> Option { + self.upgrade().and_then(|state| { + match (state.senders.get_state(), state.receivers.get_state()) { + (None, None) => None, + (senders_state, receivers_state) => Some(Self::State { + senders_state, + receivers_state, + }), + } + }) + } + } +} + +mod receive { + use std::{ + collections::BTreeMap, + fmt::{Debug, Formatter}, + }; + + use super::*; + use crate::{ + helpers::{ + error::Error, + gateway::{receive::GatewayReceivers, ReceivingEnd}, + ChannelId, Message, + }, + protocol::RecordId, + }; + + impl Observed> { + delegate::delegate! { + to { self.advance(); self.inner() } { + #[inline] + pub async fn receive(&self, record_id: RecordId) -> Result; + } + } + } + + pub struct WaitingTasks(BTreeMap>); + + impl Debug for WaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for (channel, records) in &self.0 { + write!( + f, + "\n\"{:?}\", from={:?}. Waiting to receive records {:?}.", + channel.gate, channel.role, records + )?; + } + + Ok(()) + } + } + + impl ObserveState for GatewayReceivers { + type State = WaitingTasks; + + fn get_state(&self) -> Option { + let mut map = BTreeMap::default(); + for entry in &self.inner { + let channel = entry.key(); + if let Some(waiting) = super::to_ranges(entry.value().waiting()).get_state() { + map.insert(channel.clone(), waiting); + } + } + + (!map.is_empty()).then_some(WaitingTasks(map)) + } + } +} + +mod send { + use std::{ + collections::BTreeMap, + fmt::{Debug, Formatter}, + }; + + use super::*; + use crate::{ + helpers::{ + error::Error, + gateway::send::{GatewaySender, GatewaySenders}, + ChannelId, Message, TotalRecords, + }, + protocol::RecordId, + }; + + impl Observed> { + delegate::delegate! { + to { self.advance(); self.inner() } { + #[inline] + pub async fn send(&self, record_id: RecordId, msg: M) -> Result<(), Error>; + } + } + } + + pub struct WaitingTasks(BTreeMap)>); + + impl Debug for WaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for (channel, (total, records)) in &self.0 { + write!( + f, + "\n\"{:?}\", to={:?}. Waiting to send records {:?} out of {total:?}.", + channel.gate, channel.role, records + )?; + } + + Ok(()) + } + } + + impl ObserveState for GatewaySenders { + type State = WaitingTasks; + + fn get_state(&self) -> Option { + let mut state = BTreeMap::new(); + for entry in &self.inner { + let channel = entry.key(); + let sender = entry.value(); + if let Some(sender_state) = sender.get_state() { + state.insert(channel.clone(), (sender.total_records(), sender_state)); + } + } + + (!state.is_empty()).then_some(WaitingTasks(state)) + } + } + + impl ObserveState for GatewaySender { + type State = Vec; + + fn get_state(&self) -> Option { + let waiting_indices = self.waiting(); + super::to_ranges(waiting_indices).get_state() + } + } +} + +/// Converts a vector of numbers into a vector of ranges. +/// For example, [1, 2, 3, 4, 5, 7, 9, 10, 11] produces [(1..=5), (7..=7), (9..=11)]. +fn to_ranges(nums: Vec) -> Vec> { + nums.into_iter() + .fold(Vec::>::new(), |mut ranges, num| { + if let Some(last_range) = ranges.last_mut().filter(|r| *r.end() == num - 1) { + *last_range = *last_range.start()..=num; + } else { + ranges.push(num..=num); + } + ranges + }) +} + +/// Range formatter that prints one-element wide ranges as single numbers. +impl ObserveState for Vec> +where + U: Copy + Display + Eq + PartialOrd + Ord + Sub + From, +{ + type State = Vec; + fn get_state(&self) -> Option { + let r = self + .iter() + .map( + |range| match (*range.end() - *range.start()).cmp(&U::from(1)) { + std::cmp::Ordering::Less => format!("{}", range.start()), + std::cmp::Ordering::Equal => format!("[{}, {}]", range.start(), range.end()), + std::cmp::Ordering::Greater => format!("[{}..{}]", range.start(), range.end()), + }, + ) + .collect::>(); + + (!r.is_empty()).then_some(r) + } +} diff --git a/src/helpers/gateway/transport.rs b/src/helpers/gateway/transport.rs index 94563b3c0..8c90a29ee 100644 --- a/src/helpers/gateway/transport.rs +++ b/src/helpers/gateway/transport.rs @@ -2,7 +2,7 @@ use crate::{ helpers::{ buffers::UnorderedReceiver, gateway::{receive::UR, send::GatewaySendStream}, - ChannelId, GatewayConfig, Role, RoleAssignment, RouteId, Transport, + ChannelId, GatewayConfig, Role, RoleAssignment, RouteId, Transport, TransportImpl, }, protocol::QueryId, }; @@ -12,19 +12,19 @@ use crate::{ /// /// [`HelperIdentity`]: crate::helpers::HelperIdentity #[derive(Clone)] -pub(super) struct RoleResolvingTransport { +pub(super) struct RoleResolvingTransport { pub query_id: QueryId, pub roles: RoleAssignment, pub config: GatewayConfig, - pub inner: T, + pub inner: TransportImpl, } -impl RoleResolvingTransport { +impl RoleResolvingTransport { pub(crate) async fn send( &self, channel_id: &ChannelId, data: GatewaySendStream, - ) -> Result<(), T::Error> { + ) -> Result<(), ::Error> { let dest_identity = self.roles.identity(channel_id.role); assert_ne!( dest_identity, @@ -41,7 +41,7 @@ impl RoleResolvingTransport { .await } - pub(crate) fn receive(&self, channel_id: &ChannelId) -> UR { + pub(crate) fn receive(&self, channel_id: &ChannelId) -> UR { let peer = self.roles.identity(channel_id.role); assert_ne!( peer, diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index 373070736..678202880 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -3,6 +3,8 @@ use std::{ num::NonZeroUsize, }; +use generic_array::GenericArray; + mod buffers; mod error; mod gateway; @@ -15,11 +17,33 @@ use std::ops::{Index, IndexMut}; #[cfg(test)] pub use buffers::OrderingSender; pub use error::{Error, Result}; + +#[cfg(feature = "stall-detection")] +mod gateway_exports { + use crate::helpers::{ + gateway, + gateway::{stall_detection::Observed, InstrumentedGateway}, + }; + + pub type Gateway = Observed; + pub type SendingEnd = Observed>; + pub type ReceivingEnd = Observed>; +} + +#[cfg(not(feature = "stall-detection"))] +mod gateway_exports { + use crate::helpers::gateway; + + pub type Gateway = gateway::Gateway; + pub type SendingEnd = gateway::SendingEnd; + pub type ReceivingEnd = gateway::ReceivingEnd; +} + +pub use gateway::GatewayConfig; // TODO: this type should only be available within infra. Right now several infra modules // are exposed at the root level. That makes it impossible to have a proper hierarchy here. -pub use gateway::{Gateway, TransportError, TransportImpl}; -pub use gateway::{GatewayConfig, ReceivingEnd, SendingEnd}; -use generic_array::GenericArray; +pub use gateway::{TransportError, TransportImpl}; +pub use gateway_exports::{Gateway, ReceivingEnd, SendingEnd}; pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; @@ -195,7 +219,7 @@ impl IndexMut for Vec { /// may be `H2` or `H3`. /// Each helper instance must be able to take any role, but once the role is assigned, it cannot /// be changed for the remainder of the query. -#[derive(Copy, Clone, Debug, PartialEq, Hash, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "cli", derive(clap::ValueEnum))] #[cfg_attr( feature = "enable-serde", @@ -384,7 +408,7 @@ impl TryFrom<[Role; 3]> for RoleAssignment { /// Combination of helper role and step that uniquely identifies a single channel of communication /// between two helpers. -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct ChannelId { pub role: Role, // TODO: step could be either reference or owned value. references are convenient to use inside diff --git a/src/helpers/prss_protocol.rs b/src/helpers/prss_protocol.rs index e44fec8ed..4dddd21fb 100644 --- a/src/helpers/prss_protocol.rs +++ b/src/helpers/prss_protocol.rs @@ -3,7 +3,7 @@ use rand_core::{CryptoRng, RngCore}; use x25519_dalek::PublicKey; use crate::{ - helpers::{ChannelId, Direction, Error, Gateway, TotalRecords, Transport}, + helpers::{ChannelId, Direction, Error, Gateway, TotalRecords}, protocol::{ prss, step::{Gate, Step, StepNarrow}, @@ -24,8 +24,8 @@ impl Step for PrssExchangeStep {} /// establish the prss endpoint by exchanging public keys with the other helpers /// # Errors /// if communication with other helpers fails -pub async fn negotiate( - gateway: &Gateway, +pub async fn negotiate( + gateway: &Gateway, gate: &Gate, rng: &mut R, ) -> Result { diff --git a/src/lib.rs b/src/lib.rs index 340601adc..17e78fcfa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,13 +124,21 @@ pub(crate) mod test_executor { } } -#[cfg(all(feature = "in-memory-infra", feature = "real-world-infra"))] -compile_error!("feature \"in-memory-infra\" and feature \"real-world-infra\" cannot be enabled at the same time"); +macro_rules! mutually_incompatible { + ($feature1:literal,$feature2:literal) => { + #[cfg(all(feature = $feature1, feature = $feature2))] + compile_error!(concat!( + "feature \"", + $feature1, + "\" and feature \"", + $feature2, + "\" can't be enabled at the same time" + )); + }; +} -#[cfg(all(feature = "compact-gate", feature = "descriptive-date"))] -compile_error!( - "feature \"compact-gate\" and feature \"descriptive-gate\" cannot be enabled at the same time" -); +mutually_incompatible!("in-memory-infra", "real-world-infra"); +mutually_incompatible!("compact-gate", "descriptive-gate"); #[cfg(all(not(feature = "compact-gate"), not(feature = "descriptive-gate")))] compile_error!("feature \"compact-gate\" or \"descriptive-gate\" must be enabled"); diff --git a/src/protocol/step/compact.rs b/src/protocol/step/compact.rs index 585ce2042..e65abf737 100644 --- a/src/protocol/step/compact.rs +++ b/src/protocol/step/compact.rs @@ -5,7 +5,7 @@ use ipa_macros::Gate; use super::StepNarrow; use crate::helpers::{prss_protocol::PrssExchangeStep, query::QueryType}; -#[derive(Gate, Clone, Hash, PartialEq, Eq, Default)] +#[derive(Gate, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Default)] #[cfg_attr( feature = "enable-serde", derive(serde::Deserialize), diff --git a/src/protocol/step/descriptive.rs b/src/protocol/step/descriptive.rs index 4f41a4881..dc13e40a1 100644 --- a/src/protocol/step/descriptive.rs +++ b/src/protocol/step/descriptive.rs @@ -22,7 +22,7 @@ use crate::telemetry::{labels::STEP, metrics::STEP_NARROWED}; /// Step "a" would be executed with a context identifier of "protocol/a", which it /// would `narrow()` into "protocol/a/x" and "protocol/a/y" to produce a final set /// of identifiers: ".../a/x", ".../a/y", ".../b", and ".../c". -#[derive(Clone, Hash, PartialEq, Eq)] +#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr( feature = "enable-serde", derive(serde::Deserialize), From 0b23471e201f49db328616d8ae3c74ad9017f3b5 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 3 Nov 2023 09:36:29 -0700 Subject: [PATCH 2/3] Fix formatting --- src/helpers/gateway/stall_detection.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helpers/gateway/stall_detection.rs b/src/helpers/gateway/stall_detection.rs index 2803e5aed..236b09700 100644 --- a/src/helpers/gateway/stall_detection.rs +++ b/src/helpers/gateway/stall_detection.rs @@ -179,7 +179,7 @@ mod gateway { receivers_state: Option, } - impl Debug for GatewayWaitingTasks { + impl Debug for GatewayWaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { if let Some(senders_state) = &self.senders_state { write!(f, "\n{{{senders_state:?}\n}}")?; From 88d208d870809ca14a21af0e2cd5febbb3f1fb12 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 6 Nov 2023 12:57:29 -0800 Subject: [PATCH 3/3] Update src/helpers/gateway/mod.rs Co-authored-by: Andy Leiserson --- src/helpers/gateway/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helpers/gateway/mod.rs b/src/helpers/gateway/mod.rs index 901783ea6..4fd839ca5 100644 --- a/src/helpers/gateway/mod.rs +++ b/src/helpers/gateway/mod.rs @@ -72,7 +72,7 @@ impl Gateway { roles: RoleAssignment, transport: TransportImpl, ) -> Self { - #[allow(clippy::useless_conversion)] + #[allow(clippy::useless_conversion)] // not useless in stall-detection build Self { config, transport: RoleResolvingTransport {