Instrumenting infrastructure with stall detection #832

Merged: 4 commits, Nov 6, 2023
9 changes: 8 additions & 1 deletion Cargo.toml
@@ -13,7 +13,9 @@ default = [
"tracing/release_max_level_info",
"descriptive-gate",
"aggregate-circuit",
"ipa-prf"
"stall-detection",
"aggregate-circuit",
"ipa-prf",
]
cli = ["comfy-table", "clap"]
enable-serde = ["serde", "serde_json"]
@@ -22,6 +24,10 @@ disable-metrics = []
# TODO Consider moving out benches as well
web-app = ["axum", "axum-server", "base64", "clap", "comfy-table", "enable-serde", "hyper", "hyper-rustls", "rcgen", "rustls", "rustls-pemfile", "time", "tokio-rustls", "toml", "tower", "tower-http"]
test-fixture = ["enable-serde", "weak-field"]
# Include observability instruments that detect lack of progress inside MPC. If there is a bug that leads to helper
# miscommunication, this feature helps to detect it. Turning it on has some cost.
# If "shuttle" feature is enabled, turning this on has no effect.
stall-detection = []
shuttle = ["shuttle-crate", "test-fixture"]
debug-trace = ["tracing/max_level_trace", "tracing/release_max_level_debug"]
# TODO: we may want to use in-memory-bench and real-world-bench some time after
@@ -59,6 +65,7 @@ config = "0.13.2"
criterion = { version = "0.5.1", optional = true, default-features = false, features = ["async_tokio", "plotters", "html_reports"] }
curve25519-dalek = "4.1.1"
dashmap = "5.4"
delegate = "0.10.0"
dhat = "0.3.2"
embed-doc-image = "0.1.4"
futures = "0.3.28"
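The Cargo.toml changes introduce the opt-in stall-detection feature (also added to the default feature set, and a no-op when the shuttle feature is active) plus one new dependency, the delegate crate. For builds that do not use the default features, the instrumentation can be enabled explicitly with `cargo build --features stall-detection`. The delegate crate is presumably what the instrumented wrappers under `helpers::gateway::stall_detection` use to forward unchanged methods to the wrapped type while adding bookkeeping only where needed; below is a minimal sketch of that pattern with made-up types, not code from this PR.

use delegate::delegate;

// Made-up types for illustration; the PR's real wrappers live under
// `helpers::gateway::stall_detection` and are not shown in this diff.
struct Sender {
    sent: usize,
}

impl Sender {
    fn total_sent(&self) -> usize {
        self.sent
    }
}

struct InstrumentedSender {
    inner: Sender,
}

impl InstrumentedSender {
    delegate! {
        // Forward the unchanged part of the API to the wrapped sender.
        to self.inner {
            fn total_sent(&self) -> usize;
        }
    }

    // Extra observability that only the instrumented wrapper provides.
    fn report(&self) -> String {
        format!("{} records sent", self.total_sent())
    }
}

fn main() {
    let s = InstrumentedSender { inner: Sender { sent: 3 } };
    println!("{}", s.report());
}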
34 changes: 0 additions & 34 deletions src/helpers/buffers/mod.rs
@@ -5,37 +5,3 @@ mod unordered_receiver
pub use ordering_mpsc::{ordering_mpsc, OrderingMpscReceiver, OrderingMpscSender};
pub use ordering_sender::{OrderedStream, OrderingSender};
pub use unordered_receiver::UnorderedReceiver;

#[cfg(debug_assertions)]
#[allow(unused)] // todo(alex): make test world print the state again
mod waiting {
use std::collections::HashMap;

use crate::helpers::ChannelId;

pub(in crate::helpers) struct WaitingTasks<'a> {
tasks: HashMap<&'a ChannelId, Vec<u32>>,
}

impl<'a> WaitingTasks<'a> {
pub fn new(tasks: HashMap<&'a ChannelId, Vec<u32>>) -> Self {
Self { tasks }
}

pub fn is_empty(&self) -> bool {
self.tasks.is_empty()
}
}

impl std::fmt::Debug for WaitingTasks<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
for (channel, records) in &self.tasks {
write!(f, "\n {channel:?}: {records:?}")?;
}
write!(f, "\n]")?;

Ok(())
}
}
}
23 changes: 23 additions & 0 deletions src/helpers/buffers/ordering_sender.rs
@@ -188,6 +188,11 @@ impl WaitingShard {
self.wakers.pop_front().unwrap().w.wake();
}
}

#[cfg(feature = "stall-detection")]
pub fn waiting(&self) -> impl Iterator<Item = usize> + '_ {
self.wakers.iter().map(|waker| waker.i)
}
}

/// A collection of wakers that are indexed by the send index (`i`).
@@ -224,6 +229,19 @@ impl Waiting {
fn wake(&self, i: usize) {
self.shard(i).wake(i);
}

/// Returns all records currently waiting to be sent in sorted order.
#[cfg(feature = "stall-detection")]
fn waiting(&self) -> Vec<usize> {
let mut records = Vec::new();
self.shards
.iter()
.for_each(|shard| records.extend(shard.lock().unwrap().waiting()));

records.sort_unstable();

records
}
}

/// An `OrderingSender` accepts messages for sending in any order, but
@@ -375,6 +393,11 @@ impl OrderingSender {
) -> OrderedStream<crate::sync::Arc<Self>> {
OrderedStream { sender: self }
}

#[cfg(feature = "stall-detection")]
pub fn waiting(&self) -> Vec<usize> {
self.waiting.waiting()
}
}

/// A future for writing item `i` into an `OrderingSender`.
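`WaitingShard::waiting()`, `Waiting::waiting()`, and `OrderingSender::waiting()` expose the record indices that still have registered wakers, which is the raw material a stall report needs. A hypothetical snapshot-and-compare check built on such an accessor could look like the standalone sketch below (not the PR's stall_detection module): if the non-empty set of waiting records is unchanged after the progress interval, nothing moved and those records are reported.

use std::{collections::BTreeSet, thread, time::Duration};

// Hypothetical stall check: snapshot the waiting record indices, wait for the
// progress interval, then snapshot again. An unchanged, non-empty set means
// those records made no progress during the interval.
fn detect_stall(
    mut snapshot: impl FnMut() -> BTreeSet<usize>,
    interval: Duration,
) -> Option<BTreeSet<usize>> {
    let before = snapshot();
    thread::sleep(interval);
    let after = snapshot();
    if !after.is_empty() && before == after {
        Some(after)
    } else {
        None
    }
}

fn main() {
    // Stand-in for `OrderingSender::waiting()`: pretend records 4 and 7 are stuck.
    let stuck = || BTreeSet::from([4usize, 7]);
    if let Some(stalled) = detect_stall(stuck, Duration::from_millis(10)) {
        println!("stalled records: {stalled:?}");
    }
}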
71 changes: 70 additions & 1 deletion src/helpers/buffers/unordered_receiver.rs
@@ -116,6 +116,8 @@
stream: Pin<Box<S>>,
/// The absolute index of the next value that will be received.
next: usize,
/// The maximum record index that has ever been requested to be received.
max_polled_idx: usize,
/// The underlying stream can provide chunks of data larger than a single
/// message. Save any spare data here.
spare: Spare,
@@ -143,6 +145,12 @@
/// Note: in protocols we try to send before receiving, so we can rely on
/// that easing load on this mechanism. There might also need to be some
/// end-to-end back pressure for tasks that do not involve sending at all.
///
/// If stall detection is enabled, the index of that waker is stored alongside it, in order
/// to correctly identify the `i` awaiting completion.
#[cfg(feature = "stall-detection")]
overflow_wakers: Vec<(Waker, usize)>,
#[cfg(not(feature = "stall-detection"))]
overflow_wakers: Vec<Waker>,
_marker: PhantomData<C>,
}
@@ -172,7 +180,11 @@
);
// We don't save a waker at `self.next`, so `>` and not `>=`.
if i > self.next + self.wakers.len() {
self.overflow_wakers.push(waker);
#[cfg(feature = "stall-detection")]
let overflow = (waker, i);
#[cfg(not(feature = "stall-detection"))]
let overflow = waker;
self.overflow_wakers.push(overflow);
} else {
let index = i % self.wakers.len();
if let Some(old) = self.wakers[index].as_ref() {
@@ -195,6 +207,11 @@
}
if self.next % (self.wakers.len() / 2) == 0 {
// Wake all the overflowed wakers. See comments on `overflow_wakers`.
#[cfg(feature = "stall-detection")]
for (w, _) in take(&mut self.overflow_wakers) {
w.wake();
}
#[cfg(not(feature = "stall-detection"))]
for w in take(&mut self.overflow_wakers) {
w.wake();
}
@@ -204,6 +221,7 @@
/// Poll for the next record. This should only be invoked when
/// the future for the next message is polled.
fn poll_next<M: Message>(&mut self, cx: &mut Context<'_>) -> Poll<Result<M, Error>> {
self.max_polled_idx = std::cmp::max(self.max_polled_idx, self.next);
if let Some(m) = self.spare.read() {
self.wake_next();
return Poll::Ready(Ok(m));
@@ -228,6 +246,46 @@
}
}
}

#[cfg(feature = "stall-detection")]
fn waiting(&self) -> impl Iterator<Item = usize> + '_ {
/// There is no waker for self.next and it could be advanced past the end of the stream.
/// This helps to conditionally add self.next to the waiting list.
struct MaybeNext {
currently_at: usize,
next: usize,
}
impl Iterator for MaybeNext {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
if self.currently_at == self.next {
self.currently_at += 1;
Some(self.next)
} else {
None
}
}
}

let start = self.next % self.wakers.len();
self.wakers
.iter()
.enumerate()
.filter_map(|(i, waker)| waker.as_ref().map(|_| i))
.map(move |i| {
if i < start {
self.next + (self.wakers.len() - start + i)
} else {
self.next + (i - start)
}
})
.chain(self.overflow_wakers.iter().map(|v| v.1))
.chain(MaybeNext {
currently_at: self.max_polled_idx,
next: self.next,
})
}
}

/// Take an ordered stream of bytes and make messages from that stream
@@ -262,6 +320,7 @@
inner: Arc::new(Mutex::new(OperatingState {
stream,
next: 0,
max_polled_idx: 0,
spare: Spare::default(),
wakers,
overflow_wakers: Vec::new(),
@@ -284,6 +343,16 @@
_marker: PhantomData,
}
}

#[cfg(feature = "stall-detection")]
pub fn waiting(&self) -> Vec<usize> {
let state = self.inner.lock().unwrap();
let mut r = state.waiting().collect::<Vec<_>>();

r.sort_unstable();

r
}
}

impl<S, C> Clone for UnorderedReceiver<S, C>
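The index arithmetic in the receiver's waiting() maps a slot in the circular waker buffer back to the absolute record index it belongs to, and max_polled_idx lets self.next itself be reported when it is the record currently being polled. The mapping is easier to see with concrete numbers; here is a small standalone sketch of the same arithmetic (the helper name is made up for illustration).

// Map a waker slot back to the absolute record index it is waiting on,
// mirroring the arithmetic in `OperatingState::waiting()`.
// `next` is the next record to be received, `capacity` the waker ring size.
fn slot_to_record(slot: usize, next: usize, capacity: usize) -> usize {
    let start = next % capacity;
    if slot < start {
        // The slot wrapped around the end of the ring.
        next + (capacity - start + slot)
    } else {
        next + (slot - start)
    }
}

fn main() {
    // With capacity 8 and next = 10: record 10 lives in slot 2 (10 % 8),
    // record 15 in slot 7, and record 17 wraps around to slot 1.
    assert_eq!(slot_to_record(2, 10, 8), 10);
    assert_eq!(slot_to_record(7, 10, 8), 15);
    assert_eq!(slot_to_record(1, 10, 8), 17);
}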
79 changes: 52 additions & 27 deletions src/helpers/gateway/mod.rs
@@ -1,19 +1,22 @@
mod receive;
mod send;
#[cfg(feature = "stall-detection")]
pub(super) mod stall_detection;
mod transport;

use std::{fmt::Debug, num::NonZeroUsize};
use std::num::NonZeroUsize;

pub use send::SendingEnd;
#[cfg(all(feature = "shuttle", test))]
pub(super) use receive::ReceivingEnd;
pub(super) use send::SendingEnd;
#[cfg(all(test, feature = "shuttle"))]
use shuttle::future as tokio;
#[cfg(feature = "stall-detection")]
pub(super) use stall_detection::InstrumentedGateway;

use crate::{
helpers::{
gateway::{
receive::{GatewayReceivers, ReceivingEnd as ReceivingEndBase},
send::GatewaySenders,
transport::RoleResolvingTransport,
receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport,
},
ChannelId, Message, Role, RoleAssignment, TotalRecords, Transport,
},
@@ -31,35 +34,45 @@ pub type TransportImpl = super::transport::InMemoryTransport;
pub type TransportImpl = crate::sync::Arc<crate::net::HttpTransport>;

pub type TransportError = <TransportImpl as Transport>::Error;
pub type ReceivingEnd<M> = ReceivingEndBase<TransportImpl, M>;

/// Gateway into IPA Infrastructure systems. This object allows sending and receiving messages.
/// As it is generic over network/transport layer implementation, type alias [`Gateway`] should be
/// used to avoid carrying `T` over.
///
/// [`Gateway`]: crate::helpers::Gateway
pub struct Gateway<T: Transport = TransportImpl> {
/// Gateway into IPA Network infrastructure. It allows helpers to send and receive messages.
pub struct Gateway {
config: GatewayConfig,
transport: RoleResolvingTransport<T>,
transport: RoleResolvingTransport,
#[cfg(feature = "stall-detection")]
inner: crate::sync::Arc<State>,
#[cfg(not(feature = "stall-detection"))]
inner: State,
}

#[derive(Default)]
pub struct State {
senders: GatewaySenders,
receivers: GatewayReceivers<T>,
receivers: GatewayReceivers,
}

#[derive(Clone, Copy, Debug)]
pub struct GatewayConfig {
/// The number of items that can be active at one time.
/// This is used to determine the size of sending and receiving buffers.
active: NonZeroUsize,

/// Time to wait before checking gateway progress. If no progress has been made between
/// checks, the gateway is considered to be stalled and will create a report with outstanding
/// send/receive requests.
#[cfg(feature = "stall-detection")]
pub progress_check_interval: std::time::Duration,
}

impl<T: Transport> Gateway<T> {
impl Gateway {
#[must_use]
pub fn new(
query_id: QueryId,
config: GatewayConfig,
roles: RoleAssignment,
transport: T,
transport: TransportImpl,
) -> Self {
#[allow(clippy::useless_conversion)] // not useless in stall-detection build
Self {
config,
transport: RoleResolvingTransport {
@@ -68,8 +81,7 @@ impl<T: Transport> Gateway<T> {
inner: transport,
config,
},
senders: GatewaySenders::default(),
receivers: GatewayReceivers::default(),
inner: State::default().into(),
}
}

@@ -91,10 +103,12 @@
&self,
channel_id: &ChannelId,
total_records: TotalRecords,
) -> SendingEnd<M> {
let (tx, maybe_stream) =
self.senders
.get_or_create::<M>(channel_id, self.config.active_work(), total_records);
) -> send::SendingEnd<M> {
let (tx, maybe_stream) = self.inner.senders.get_or_create::<M>(
channel_id,
self.config.active_work(),
total_records,
);
if let Some(stream) = maybe_stream {
tokio::spawn({
let channel_id = channel_id.clone();
@@ -109,14 +123,15 @@
});
}

SendingEnd::new(tx, self.role(), channel_id)
send::SendingEnd::new(tx, self.role(), channel_id)
}

#[must_use]
pub fn get_receiver<M: Message>(&self, channel_id: &ChannelId) -> ReceivingEndBase<T, M> {
ReceivingEndBase::new(
pub fn get_receiver<M: Message>(&self, channel_id: &ChannelId) -> receive::ReceivingEnd<M> {
receive::ReceivingEnd::new(
channel_id.clone(),
self.receivers
self.inner
.receivers
.get_or_create(channel_id, || self.transport.receive(channel_id)),
)
}
@@ -135,8 +150,18 @@ impl GatewayConfig {
/// If `active` is 0.
#[must_use]
pub fn new(active: usize) -> Self {
// In-memory tests are fast, so progress check intervals can be lower.
// Real world scenarios currently over-report stalls because of inefficiencies inside
// infrastructure and actual networking issues. This check is only valuable for reporting
// bugs, so keep it large enough to avoid false positives.
Self {
active: NonZeroUsize::new(active).unwrap(),
#[cfg(feature = "stall-detection")]
progress_check_interval: std::time::Duration::from_secs(if cfg!(test) {
5
} else {
30
}),
}
}

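GatewayConfig::progress_check_interval (5 seconds under cfg!(test), 30 seconds otherwise) drives a periodic check of outstanding work; the InstrumentedGateway that consumes it is referenced but not shown in this diff. The general shape of such a monitor, sketched below with a plain counter instead of the real waiting() reports, is only an approximation of the PR's implementation.

use std::{
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    time::Duration,
};

// Hypothetical progress monitor: every `interval`, compare a progress counter
// against its previous value and complain if nothing moved. The PR reports the
// `waiting()` sets of senders and receivers instead of a counter.
async fn monitor_progress(progress: Arc<AtomicUsize>, interval: Duration) {
    let mut last_seen = progress.load(Ordering::Relaxed);
    loop {
        tokio::time::sleep(interval).await;
        let now = progress.load(Ordering::Relaxed);
        if now == last_seen {
            eprintln!("no progress for {interval:?}; possible stall");
        }
        last_seen = now;
    }
}

#[tokio::main]
async fn main() {
    let progress = Arc::new(AtomicUsize::new(0));
    tokio::spawn(monitor_progress(Arc::clone(&progress), Duration::from_millis(100)));
    // The protocol would bump the counter as records are sent and received.
    progress.fetch_add(1, Ordering::Relaxed);
    tokio::time::sleep(Duration::from_millis(250)).await;
}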