Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix send buffer misalignment [No 2] #1332

Merged
1 change: 1 addition & 0 deletions ipa-core/benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl Args {
self.active_work
.map(NonZeroUsize::get)
.unwrap_or_else(|| self.query_size.clamp(16, 1024))
.next_power_of_two()
}

fn attribution_window(&self) -> Option<NonZeroU32> {
Expand Down
7 changes: 4 additions & 3 deletions ipa-core/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{num::NonZeroUsize, sync::Weak};
use std::sync::Weak;

use async_trait::async_trait;

Expand All @@ -13,17 +13,18 @@
protocol::QueryId,
query::{NewQueryError, QueryProcessor, QueryStatus},
sync::Arc,
utils::NonZeroU32PowerOfTwo,
};

#[derive(Default)]
pub struct AppConfig {
active_work: Option<NonZeroUsize>,
active_work: Option<NonZeroU32PowerOfTwo>,
key_registry: Option<KeyRegistry<PrivateKeyOnly>>,
}

impl AppConfig {
#[must_use]
pub fn with_active_work(mut self, active_work: Option<NonZeroUsize>) -> Self {
pub fn with_active_work(mut self, active_work: Option<NonZeroU32PowerOfTwo>) -> Self {

Check warning on line 27 in ipa-core/src/app.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/app.rs#L27

Added line #L27 was not covered by tests
self.active_work = active_work;
self
}
Expand Down
5 changes: 2 additions & 3 deletions ipa-core/src/bin/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::{
fs,
io::BufReader,
net::TcpListener,
num::NonZeroUsize,
os::fd::{FromRawFd, RawFd},
path::{Path, PathBuf},
process,
Expand All @@ -18,7 +17,7 @@ use ipa_core::{
error::BoxError,
helpers::HelperIdentity,
net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient},
AppConfig, AppSetup,
AppConfig, AppSetup, NonZeroU32PowerOfTwo,
};
use tracing::{error, info};

Expand Down Expand Up @@ -93,7 +92,7 @@ struct ServerArgs {

/// Override the amount of active work processed in parallel
#[arg(long)]
active_work: Option<NonZeroUsize>,
active_work: Option<NonZeroU32PowerOfTwo>,
}

#[derive(Debug, Subcommand)]
Expand Down
238 changes: 226 additions & 12 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
protocol::QueryId,
sharding::ShardIndex,
sync::{Arc, Mutex},
utils::NonZeroU32PowerOfTwo,
};

/// Alias for the currently configured transport.
Expand Down Expand Up @@ -73,7 +74,7 @@
pub struct GatewayConfig {
/// The number of items that can be active at the one time.
/// This is used to determine the size of sending and receiving buffers.
pub active: NonZeroUsize,
pub active: NonZeroU32PowerOfTwo,

/// Number of bytes packed and sent together in one batch down to the network layer. This
/// shouldn't be too small to keep the network throughput, but setting it large enough may
Expand All @@ -84,6 +85,10 @@
/// payload may not be exactly this, but it will be the closest multiple of record size to this
/// number. For instance, having 14 bytes records and batch size of 4096 will result in
/// 4088 bytes being sent in a batch.
///
/// The actual size for read chunks may be bigger or smaller, depending on the record size
/// sent through each channel. Read size will be aligned with [`Self::active_work`] value to
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment explaining what to do if we want to align read size perfectly with the target read size

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest rewriting this entire comment. It feels out of date.

  • line 85 has a typo "side" instead of "size".
  • line 86 refers to "batch size" but this comment is above "read_size".
  • I think the example is wrong, because 4088 bytes is not a power of two multiple of 14. If I'm not wrong, the best you can do is now 14 * 2^8 < 4096 (it's 3,584).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are absolutely right. Just did that

/// prevent deadlocks.
pub read_size: NonZeroUsize,

/// Time to wait before checking gateway progress. If no progress has been made between
Expand Down Expand Up @@ -150,7 +155,7 @@
&self,
channel_id: &HelperChannelId,
total_records: TotalRecords,
active_work: NonZeroUsize,
active_work: NonZeroU32PowerOfTwo,
) -> send::SendingEnd<Role, M> {
let transport = &self.transports.mpc;
let channel = self.inner.mpc_senders.get::<M, _>(
Expand Down Expand Up @@ -260,6 +265,11 @@
/// The configured amount of active work.
#[must_use]
pub fn active_work(&self) -> NonZeroUsize {
self.active.to_non_zero_usize()
}

#[must_use]
pub fn active_work_as_power_of_two(&self) -> NonZeroU32PowerOfTwo {
self.active
}

Expand All @@ -279,14 +289,15 @@
// capabilities (see #ipa/1171) to allow that currently.
usize::from(value.size),
),
);
)
.next_power_of_two();
// we set active to be at least 2, so unwrap is fine.
self.active = NonZeroUsize::new(active).unwrap();
self.active = NonZeroU32PowerOfTwo::try_from(active).unwrap();
}

/// Creates a new configuration by overriding the value of active work.
#[must_use]
pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self {
pub fn set_active_work(&self, active_work: NonZeroU32PowerOfTwo) -> Self {
Self {
active: active_work,
..*self
Expand All @@ -298,27 +309,39 @@
mod tests {
use std::{
iter::{repeat, zip},
num::NonZeroUsize,
sync::Arc,
};

use futures::{
future::{join, try_join, try_join_all},
stream,
stream::StreamExt,
};
use proptest::proptest;

use crate::{
ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions},
ff::{
boolean_array::{BA20, BA256, BA3, BA4, BA5, BA6, BA7, BA8},
FieldType, Fp31, Fp32BitPrime, Gf2, U128Conversions,
},
helpers::{
ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords,
gateway::QueryConfig,
query::{QuerySize, QueryType},
ChannelId, Direction, GatewayConfig, MpcMessage, MpcReceivingEnd, Role, SendingEnd,
TotalRecords,
},
protocol::{
context::{Context, ShardedContext},
Gate, RecordId,
},
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
secret_sharing::{
replicated::semi_honest::AdditiveShare, SharedValue, SharedValueArray, StdArray,
},
seq_join::seq_join,
sharding::ShardConfiguration,
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
utils::NonZeroU32PowerOfTwo,
};

/// Verifies that [`Gateway`] send buffer capacity is adjusted to the message size.
Expand Down Expand Up @@ -538,13 +561,19 @@
run(|| async move {
let world = TestWorld::new_with(TestWorldConfig {
gateway_config: GatewayConfig {
active: 5.try_into().unwrap(),
active: 8.try_into().unwrap(),
..Default::default()
},
..Default::default()
});
let new_active_work = NonZeroUsize::new(3).unwrap();
assert!(new_active_work < world.gateway(Role::H1).config().active_work());
let new_active_work = NonZeroU32PowerOfTwo::try_from(4).unwrap();
assert!(
new_active_work
< world
.gateway(Role::H1)
.config()
.active_work_as_power_of_two()
);
let sender = world.gateway(Role::H1).get_mpc_sender::<BA3>(
&ChannelId::new(Role::H2, Gate::default()),
TotalRecords::specified(15).unwrap(),
Expand All @@ -569,6 +598,87 @@
});
}

macro_rules! send_recv_test {
(
message: $message:expr,
read_size: $read_size:expr,
active_work: $active_work:expr,
total_records: $total_records:expr,
$test_fn: ident
) => {
#[test]
fn $test_fn() {
run(|| async {
send_recv($read_size, $active_work, $total_records, $message).await;
});
}
};
}

send_recv_test! {
message: BA20::ZERO,
read_size: 5,
active_work: 8,
total_records: 25,
test_ba20_5_10_25
}

send_recv_test! {
message: StdArray::<BA256, 16>::ZERO_ARRAY,
read_size: 2048,
active_work: 16,
total_records: 43,
test_ba256_by_16_2048_10_43
}

send_recv_test! {
message: StdArray::<BA8, 16>::ZERO_ARRAY,
read_size: 2048,
active_work: 32,
total_records: 50,
test_ba8_by_16_2048_37_50
}

proptest! {
#[test]
fn send_recv_randomized(
total_records in 1_usize..500,
active in 2_usize..1000,
read_size in (1_usize..32768),
record_size in 1_usize..=8,
) {
let active = active.next_power_of_two();
run(move || async move {
match record_size {
1 => send_recv(read_size, active, total_records, StdArray::<BA8, 32>::ZERO_ARRAY).await,
2 => send_recv(read_size, active, total_records, StdArray::<BA8, 64>::ZERO_ARRAY).await,
3 => send_recv(read_size, active, total_records, BA3::ZERO).await,
4 => send_recv(read_size, active, total_records, BA4::ZERO).await,
5 => send_recv(read_size, active, total_records, BA5::ZERO).await,
6 => send_recv(read_size, active, total_records, BA6::ZERO).await,
7 => send_recv(read_size, active, total_records, BA7::ZERO).await,
8 => send_recv(read_size, active, total_records, StdArray::<BA256, 16>::ZERO_ARRAY).await,
_ => unreachable!(),

Check warning on line 661 in ipa-core/src/helpers/gateway/mod.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/helpers/gateway/mod.rs#L661

Added line #L661 was not covered by tests
}
});
}
}

/// ensures when active work is set from query input, it is always a power of two
#[test]
fn gateway_config_active_work_power_of_two() {
let mut config = GatewayConfig {
active: 2.try_into().unwrap(),
..Default::default()
};
config.set_active_work_from_query_config(&QueryConfig {
size: QuerySize::try_from(5).unwrap(),
field_type: FieldType::Fp31,
query_type: QueryType::TestAddInPrimeField,
});
assert_eq!(8, config.active_work().get());
}

async fn shard_comms_test(test_world: &TestWorld<WithShards<2>>) {
let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)];

Expand Down Expand Up @@ -606,4 +716,108 @@
let world_ptr = world as *mut _;
(world, world_ptr)
}

/// This serves the purpose of randomized testing of our send channels by providing
/// variable sizes for read size, active work and record size
async fn send_recv<M>(read_size: usize, active_work: usize, total_records: usize, sample: M)
where
M: MpcMessage + Clone + PartialEq,
{
fn duplex_channel<M: MpcMessage>(
world: &TestWorld,
left: Role,
right: Role,
total_records: usize,
active_work: usize,
) -> (SendingEnd<Role, M>, MpcReceivingEnd<M>) {
(
world.gateway(left).get_mpc_sender::<M>(
&ChannelId::new(right, Gate::default()),
TotalRecords::specified(total_records).unwrap(),
active_work.try_into().unwrap(),
),
world
.gateway(right)
.get_mpc_receiver::<M>(&ChannelId::new(left, Gate::default())),
)
}

async fn circuit<M>(
send_channel: SendingEnd<Role, M>,
recv_channel: MpcReceivingEnd<M>,
active_work: usize,
total_records: usize,
msg: M,
) where
M: MpcMessage + Clone + PartialEq,
{
let send_notify = Arc::new(tokio::sync::Notify::new());

// perform "multiplication-like" operation (send + subsequent receive)
// and "validate": block the future until we have at least `active_work`
// futures pending and unblock them all at the same time
seq_join(
active_work.try_into().unwrap(),
stream::iter(std::iter::repeat(msg).take(total_records).enumerate()).map(
|(record_id, msg)| {
let send_channel = &send_channel;
let recv_channel = &recv_channel;
let send_notify = Arc::clone(&send_notify);
async move {
send_channel
.send(record_id.into(), msg.clone())
.await

Check warning on line 769 in ipa-core/src/helpers/gateway/mod.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/helpers/gateway/mod.rs#L769

Added line #L769 was not covered by tests
.unwrap();
let r = recv_channel.receive(record_id.into()).await.unwrap();
// this simulates validate_record API by forcing futures to wait
// until the entire batch is validated by the last future in that batch
if record_id % active_work == active_work - 1
|| record_id == total_records - 1
{
send_notify.notify_waiters();
} else {
send_notify.notified().await;
}
assert_eq!(msg, r);
}
},
),
)
.collect::<Vec<_>>()
.await;
}

let config = TestWorldConfig {
gateway_config: GatewayConfig {
active: active_work.try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
..Default::default()
},
..Default::default()
};

let world = TestWorld::new_with(&config);
let (h1_send_channel, h1_recv_channel) =
duplex_channel(&world, Role::H1, Role::H2, total_records, active_work);
let (h2_send_channel, h2_recv_channel) =
duplex_channel(&world, Role::H2, Role::H1, total_records, active_work);

join(
circuit(
h1_send_channel,
h1_recv_channel,
active_work,
total_records,
sample.clone(),
),
circuit(
h2_send_channel,
h2_recv_channel,
active_work,
total_records,
sample,
),
)
.await;
}
}
Loading
Loading