From f5909791128669198ed1cf8e0bbd2235b48f79de Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 1 Oct 2024 22:07:23 -0700 Subject: [PATCH] Enforce active work to be a power of two This is the second attempt to mitigate send buffer misalignment. Previous one (#1307) didn't handle all the edge cases and was abandoned in favour of this PR. What I believe makes this change work is the new requirement for active work to be a power of two. With this constraint, it is much easier to align the read size with it. Given that `total_capacity = active * record_size`, we can represent `read_size` as a multiple of `record_size` too: `read_size = X * record_size`. If X is a power of two and active_work is a power of two, then they will always be aligned with each other. For example, if active work is 16, read size is 10 bytes and record size is 3 bytes, then: ``` total_capacity = 16*3 read_size = X*3 (close to 10) X = 2 (power of two that satisfies the requirement) ``` When picking the read size, we are rounding down to avoid buffer overflows. In the example above, setting X=4 (12 bytes) would make it closer to the desired read size than X=2 (6 bytes), but 12 is greater than 10, so we pick 2 instead. --- ipa-core/src/helpers/gateway/mod.rs | 211 ++++++++++++++++++++++++++- ipa-core/src/helpers/gateway/send.rs | 96 ++++++++++-- 2 files changed, 289 insertions(+), 18 deletions(-) diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index e654f85f7..690797958 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -73,6 +73,7 @@ pub struct State { pub struct GatewayConfig { /// The number of items that can be active at the one time. /// This is used to determine the size of sending and receiving buffers. + /// Any value that is not a power of two will be rejected pub active: NonZeroUsize, /// Number of bytes packed and sent together in one batch down to the network layer. 
This @@ -84,6 +85,10 @@ pub struct GatewayConfig { /// payload may not be exactly this, but it will be the closest multiple of record size to this /// number. For instance, having 14 bytes records and batch size of 4096 will result in /// 4088 bytes being sent in a batch. + /// + /// The actual size for read chunks may be bigger or smaller, depending on the record size + /// sent through each channel. Read size will be aligned with [`Self::active_work`] value to + /// prevent deadlocks. pub read_size: NonZeroUsize, /// Time to wait before checking gateway progress. If no progress has been made between @@ -279,7 +284,8 @@ impl GatewayConfig { // capabilities (see #ipa/1171) to allow that currently. usize::from(value.size), ), - ); + ) + .next_power_of_two(); // we set active to be at least 2, so unwrap is fine. self.active = NonZeroUsize::new(active).unwrap(); } @@ -299,23 +305,35 @@ mod tests { use std::{ iter::{repeat, zip}, num::NonZeroUsize, + sync::Arc, }; use futures::{ future::{join, try_join, try_join_all}, + stream, stream::StreamExt, }; + use proptest::proptest; use crate::{ - ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions}, + ff::{ + boolean_array::{BA20, BA256, BA3, BA4, BA5, BA6, BA7, BA8}, + FieldType, Fp31, Fp32BitPrime, Gf2, U128Conversions, + }, helpers::{ - ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords, + gateway::QueryConfig, + query::{QuerySize, QueryType}, + ChannelId, Direction, GatewayConfig, MpcMessage, MpcReceivingEnd, Role, SendingEnd, + TotalRecords, }, protocol::{ context::{Context, ShardedContext}, Gate, RecordId, }, - secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare, SharedValue, SharedValueArray, StdArray, + }, + seq_join::seq_join, sharding::ShardConfiguration, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards}, @@ -569,6 +587,87 @@ mod tests { }); } + 
macro_rules! send_recv_test { + ( + message: $message:expr, + read_size: $read_size:expr, + active_work: $active_work:expr, + total_records: $total_records:expr, + $test_fn: ident + ) => { + #[test] + fn $test_fn() { + run(|| async { + send_recv($read_size, $active_work, $total_records, $message).await; + }); + } + }; + } + + send_recv_test! { + message: BA20::ZERO, + read_size: 5, + active_work: 8, + total_records: 25, + test_ba20_5_10_25 + } + + send_recv_test! { + message: StdArray::::ZERO_ARRAY, + read_size: 2048, + active_work: 16, + total_records: 43, + test_ba256_by_16_2048_10_43 + } + + send_recv_test! { + message: StdArray::::ZERO_ARRAY, + read_size: 2048, + active_work: 32, + total_records: 50, + test_ba8_by_16_2048_37_50 + } + + proptest! { + #[test] + fn send_recv_randomized( + total_records in 1_usize..10_000, + active in 1_usize..10_000, + read_size in (1_usize..32768), + record_size in 1_usize..=8, + ) { + let active = active.next_power_of_two(); + run(move || async move { + match record_size { + 1 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + 2 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + 3 => send_recv(read_size, active, total_records, BA3::ZERO).await, + 4 => send_recv(read_size, active, total_records, BA4::ZERO).await, + 5 => send_recv(read_size, active, total_records, BA5::ZERO).await, + 6 => send_recv(read_size, active, total_records, BA6::ZERO).await, + 7 => send_recv(read_size, active, total_records, BA7::ZERO).await, + 8 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + _ => unreachable!(), + } + }); + } + } + + /// ensures when active work is set from query input, it is always a power of two + #[test] + fn gateway_config_active_work_power_of_two() { + let mut config = GatewayConfig { + active: 2.try_into().unwrap(), + ..Default::default() + }; + config.set_active_work_from_query_config(&QueryConfig { + size: QuerySize::try_from(5).unwrap(), + 
field_type: FieldType::Fp31, + query_type: QueryType::TestAddInPrimeField, + }); + assert_eq!(8, config.active_work().get()); + } + async fn shard_comms_test(test_world: &TestWorld>) { let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)]; @@ -606,4 +705,108 @@ mod tests { let world_ptr = world as *mut _; (world, world_ptr) } + + /// This serves the purpose of randomized testing of our send channels by providing + /// variable sizes for read size, active work and record size + async fn send_recv(read_size: usize, active_work: usize, total_records: usize, sample: M) + where + M: MpcMessage + Clone + PartialEq, + { + fn duplex_channel( + world: &TestWorld, + left: Role, + right: Role, + total_records: usize, + active_work: usize, + ) -> (SendingEnd, MpcReceivingEnd) { + ( + world.gateway(left).get_mpc_sender::( + &ChannelId::new(right, Gate::default()), + TotalRecords::specified(total_records).unwrap(), + active_work.try_into().unwrap(), + ), + world + .gateway(right) + .get_mpc_receiver::(&ChannelId::new(left, Gate::default())), + ) + } + + async fn circuit( + send_channel: SendingEnd, + recv_channel: MpcReceivingEnd, + active_work: usize, + total_records: usize, + msg: M, + ) where + M: MpcMessage + Clone + PartialEq, + { + let send_notify = Arc::new(tokio::sync::Notify::new()); + + // perform "multiplication-like" operation (send + subsequent receive) + // and "validate": block the future until we have at least `active_work` + // futures pending and unblock them all at the same time + seq_join( + active_work.try_into().unwrap(), + stream::iter(std::iter::repeat(msg).take(total_records).enumerate()).map( + |(record_id, msg)| { + let send_channel = &send_channel; + let recv_channel = &recv_channel; + let send_notify = Arc::clone(&send_notify); + async move { + send_channel + .send(record_id.into(), msg.clone()) + .await + .unwrap(); + let r = recv_channel.receive(record_id.into()).await.unwrap(); + // this simulates validate_record API by forcing 
futures to wait + // until the entire batch is validated by the last future in that batch + if record_id % active_work == active_work - 1 + || record_id == total_records - 1 + { + send_notify.notify_waiters(); + } else { + send_notify.notified().await; + } + assert_eq!(msg, r); + } + }, + ), + ) + .collect::>() + .await; + } + + let config = TestWorldConfig { + gateway_config: GatewayConfig { + active: active_work.try_into().unwrap(), + read_size: read_size.try_into().unwrap(), + ..Default::default() + }, + ..Default::default() + }; + + let world = TestWorld::new_with(&config); + let (h1_send_channel, h1_recv_channel) = + duplex_channel(&world, Role::H1, Role::H2, total_records, active_work); + let (h2_send_channel, h2_recv_channel) = + duplex_channel(&world, Role::H2, Role::H1, total_records, active_work); + + join( + circuit( + h1_send_channel, + h1_recv_channel, + active_work, + total_records, + sample.clone(), + ), + circuit( + h2_send_channel, + h2_recv_channel, + active_work, + total_records, + sample, + ), + ) + .await; + } } diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 07018fb14..e75cac7b2 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -248,28 +248,43 @@ impl Stream for GatewaySendStream { impl SendChannelConfig { fn new(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self { - debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0"); + Self::new_with(gateway_config, total_records, M::Size::USIZE) + } + fn new_with( + gateway_config: GatewayConfig, + total_records: TotalRecords, + record_size: usize, + ) -> Self { + debug_assert!(record_size > 0, "Message size cannot be 0"); + debug_assert!( + gateway_config.active.is_power_of_two(), + "Active work {} must be a power of two", + gateway_config.active.get() + ); - let record_size = M::Size::USIZE; let total_capacity = gateway_config.active.get() * record_size; - Self { + // define read size in 
terms of percentage of active work, rather than bytes. + // both are powers of two, so it should always be possible. We pick the read size + // to be the closest to the configuration value in bytes. + // let read_size = closest_multiple(record_size, gateway_config.read_size.get()); + let read_size = (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2 + * record_size; + let this = Self { total_capacity: total_capacity.try_into().unwrap(), record_size: record_size.try_into().unwrap(), - read_size: if total_records.is_indeterminate() - || gateway_config.read_size.get() <= record_size - { + read_size: if total_records.is_indeterminate() || read_size <= record_size { record_size } else { - std::cmp::min( - total_capacity, - // closest multiple of record_size to read_size - gateway_config.read_size.get() / record_size * record_size, - ) + std::cmp::min(total_capacity, read_size) } .try_into() .unwrap(), total_records, - } + }; + + debug_assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + + this } } @@ -277,6 +292,7 @@ impl SendChannelConfig { mod test { use std::num::NonZeroUsize; + use proptest::proptest; use typenum::Unsigned; use crate::{ @@ -379,15 +395,67 @@ mod test { fn config_read_size_closest_multiple_to_record_size() { assert_eq!( 6, - send_config::(TotalRecords::Specified(2.try_into().unwrap())) + send_config::(TotalRecords::Specified(2.try_into().unwrap())) .read_size .get() ); assert_eq!( 6, - send_config::(TotalRecords::Specified(2.try_into().unwrap())) + send_config::(TotalRecords::Specified(2.try_into().unwrap())) .read_size .get() ); } + + #[test] + fn config_read_size_record_size_misalignment() { + ensure_config(Some(15), 90, 16, 3); + } + + fn ensure_config( + total_records: Option, + active: usize, + read_size: usize, + record_size: usize, + ) { + let gateway_config = GatewayConfig { + active: active.next_power_of_two().try_into().unwrap(), + read_size: read_size.try_into().unwrap(), + // 
read_size: read_size.next_power_of_two().try_into().unwrap(), + ..Default::default() + }; + let config = SendChannelConfig::new_with( + gateway_config, + total_records.map_or(TotalRecords::Indeterminate, |v| { + TotalRecords::specified(v).unwrap() + }), + record_size, + ); + + // total capacity checks + assert!(config.total_capacity.get() > 0); + assert!(config.total_capacity.get() >= config.read_size.get()); + assert_eq!(0, config.total_capacity.get() % config.record_size.get()); + assert_eq!( + config.total_capacity.get(), + record_size * gateway_config.active.get() + ); + + // read size checks + assert!(config.read_size.get() > 0); + assert!(config.read_size.get() >= config.record_size.get()); + assert_eq!(0, config.total_capacity.get() % config.read_size.get()); + } + + proptest! { + #[test] + fn config_prop( + total_records in proptest::option::of(1_usize..1 << 32), + active in 1_usize..100_000, + read_size in 1_usize..32768, + record_size in 1_usize..4096, + ) { + ensure_config(total_records, active, read_size, record_size); + } + } }