Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix send buffer misalignment issue #1307

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 122 additions & 19 deletions ipa-core/src/helpers/gateway/send.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Borrow,
cmp::min,
fmt::Debug,
marker::PhantomData,
num::NonZeroUsize,
Expand Down Expand Up @@ -249,35 +250,79 @@ impl<I: Debug> Stream for GatewaySendStream<I> {

impl SendChannelConfig {
fn new<M: Message>(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self {
debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0");
Self::new_with(gateway_config, total_records, M::Size::USIZE)
}

let record_size = M::Size::USIZE;
let total_capacity = gateway_config.active.get() * record_size;
Self {
total_capacity: total_capacity.try_into().unwrap(),
record_size: record_size.try_into().unwrap(),
read_size: if total_records.is_indeterminate()
|| gateway_config.read_size.get() <= record_size
{
fn new_with(
gateway_config: GatewayConfig,
total_records: TotalRecords,
record_size: usize,
) -> Self {
debug_assert!(record_size > 0, "Message size cannot be 0");
// The absolute minimum of capacity we reserve for this channel. We can't go
// below that number, otherwise a deadlock is almost guaranteed.
let min_capacity = gateway_config.active.get() * record_size;

// first, compute the read size. It must be a multiple of `record_size` to prevent
// misaligned reads and deadlocks. For indeterminate channels, read size must be
// set to the size of one record, to trigger buffer flush on every write
let read_size =
if total_records.is_indeterminate() || gateway_config.read_size.get() <= record_size {
record_size
} else {
std::cmp::min(
total_capacity,
// closest multiple of record_size to read_size
// closest multiple of record_size to read_size
let proposed_read_size = min(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this isn't supposed to be max?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it probably shouldn't. If read_size goes above capacity, then we will never read anything from that buffer because even when it is 100% full, read_size is still larger, so HTTP will back off

gateway_config.read_size.get() / record_size * record_size,
)
}
.try_into()
.unwrap(),
min_capacity,
);
// if min capacity is not a multiple of read size.
// we must adjust read size. Adjusting total capacity is not possible due to
// possible deadlocks - it must be strictly aligned with active work.
// read size goes in `record_size` increments to keep it aligned.
// rem is aligned with both capacity and read_size, so subtracting
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is "rem"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be record_size

// it will keep read_size and capacity aligned
// Here is an example how it may work:
// lets say the active work is set to 10, record size is 512 bytes
// and read size in gateway config is set to 2048 bytes (default value).
// the math above will compute total_capacity to 5120 bytes and
// proposed_read_size to 2048 because it is aligned with 512 record size.
// Now, if we don't adjust then we have an issue as 5120 % 2048 = 1024 != 0.
// Keeping read size like this will cause a deadlock, so we adjust it to
// 1024.
proposed_read_size - min_capacity % proposed_read_size
};

// total capacity must be a multiple of both read size and record size.
// Record size is easy to justify: misalignment here leads to either waste of memory
// or deadlock on the last write. Aligning read size and total capacity
// has the same reasoning behind it: reading less than total capacity
// can leave the last chunk half-written and backpressure from active work
// preventing the protocols to make further progress.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too complicated. I worry that this code will become unmaintainable as this is so impenetrable that people simply will not understand it. I certainly don't understand it and we discussed it in person today and I've read this twice.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the explanation is overly complicated here. All we need to say here is that all 3 parameters need to be aligned. read_size is a multiple of record_size and total_capacity is a multiple of read_size

let total_capacity = min_capacity / read_size * read_size;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

record_size = 128
active_work = 33
min_capacity = 4224
proposed_read_size = 2048
read_size = 2048 - 128 = 1920
total_capacity = 3840


let this = Self {
total_capacity: total_capacity.try_into().unwrap(),
record_size: record_size.try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
total_records,
}
};

// make sure we've set these values correctly.
debug_assert_eq!(0, this.total_capacity.get() % this.read_size.get());
debug_assert_eq!(0, this.total_capacity.get() % this.record_size.get());
debug_assert!(this.total_capacity.get() >= this.read_size.get());
debug_assert!(this.total_capacity.get() >= this.record_size.get());
debug_assert!(this.read_size.get() >= this.record_size.get());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the comments above, I think you also want to check that:

debug_assert_eq!(0, this.read_size.get() % this.record_size.get());

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And there is another condition that the comments seem to suggest we should check, which is:

debug_assert!(this.total_capacity.get() >= gateway_config.active.get() * record_size);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the addition of these two checks, I think this set of checks are equivalent to:

  1. read_size = a * record_size
  2. total_capacity = b * read_size
  3. a * b >= active_work

Where a and b are positive integers >= 1.

Can we just find a and b like this:

if total_records.is_indeterminate() || gateway_config.read_size.get() <= record_size {
    a = 1;
    b = active_work;
} else {
    a = gateway_config.read_size.get() / record_size;
    b = ceil(active_work / a);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

record_size = 128
active_work = 33
a = 2048 / 128 = 16
b = ceil(33 / 16) = 3
read_size = a * record_size = 2048
total_capacity = b * read_size = 6144

The read size (2048) does not divide the active_work-determined capacity (4224), so this will hang.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, any prime number for active work makes this broken


this
}
}

#[cfg(test)]
mod test {
use std::num::NonZeroUsize;

use proptest::proptest;
use typenum::Unsigned;

use crate::{
Expand All @@ -286,7 +331,7 @@ mod test {
Serializable,
},
helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords},
secret_sharing::SharedValue,
secret_sharing::{Sendable, StdArray},
};

impl Default for SendChannelConfig {
Expand All @@ -301,7 +346,7 @@ mod test {
}

#[allow(clippy::needless_update)] // to allow progress_check_interval to be conditionally compiled
fn send_config<V: SharedValue, const A: usize, const R: usize>(
fn send_config<V: Sendable, const A: usize, const R: usize>(
total_records: TotalRecords,
) -> SendChannelConfig {
let gateway_config = GatewayConfig {
Expand Down Expand Up @@ -391,4 +436,62 @@ mod test {
.get()
);
}

/// This test reproduces ipa/#1300. PRF evaluation sent 32*16 = 512 (`record_size` * vectorization)
/// chunks through a channel with total capacity 5120 (active work = 10 records) and read size
/// of 2048 bytes.
/// The problem was that read size of 2048 does not divide 5120, so the last chunk was not sent.
#[test]
fn total_capacity_is_a_multiple_of_read_size() {
let config =
send_config::<StdArray<BA256, 16>, 10, 2048>(TotalRecords::specified(43).unwrap());

assert_eq!(0, config.total_capacity.get() % config.read_size.get());
assert_eq!(config.total_capacity.get(), 10 * config.record_size.get());
Comment on lines +448 to +451
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the number 10 on line 447 and line 450 are meant to be kept in sync, can you use a constant for them?

}

fn ensure_config(
total_records: Option<usize>,
active: usize,
read_size: usize,
record_size: usize,
) {
let gateway_config = GatewayConfig {
active: active.try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
..Default::default()
};
let config = SendChannelConfig::new_with(
gateway_config,
total_records.map_or(TotalRecords::Indeterminate, |v| {
TotalRecords::specified(v).unwrap()
}),
record_size,
);

// total capacity checks
assert!(config.total_capacity.get() > 0);
assert!(config.total_capacity.get() >= record_size);
assert!(config.total_capacity.get() <= record_size * active);
assert!(config.total_capacity.get() >= config.read_size.get());
assert_eq!(0, config.total_capacity.get() % config.record_size.get());

// read size checks
assert!(config.read_size.get() > 0);
assert!(config.read_size.get() >= config.record_size.get());
assert_eq!(0, config.total_capacity.get() % config.read_size.get());
assert_eq!(0, config.read_size.get() % config.record_size.get());
Comment on lines +473 to +484
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are duplicated between the test and the actual code. I don't think there is a need for this duplication. Just leaving them in the function seems sufficient.

}

proptest! {
#[test]
fn config_prop(
total_records in proptest::option::of(1_usize..1 << 32),
active in 1_usize..100_000,
read_size in 1_usize..32768,
record_size in 1_usize..4096,
) {
ensure_config(total_records, active, read_size, record_size);
}
}
}
11 changes: 7 additions & 4 deletions ipa-core/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,12 @@ pub fn test_network<T: NetworkTest>(https: bool) {
T::execute(path, https);
}

pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) {
test_ipa_with_config(
pub fn test_ipa<const INPUT_SIZE: usize>(
mode: IpaSecurityModel,
https: bool,
encrypted_inputs: bool,
) {
test_ipa_with_config::<INPUT_SIZE>(
mode,
https,
IpaQueryConfig {
Expand All @@ -228,7 +232,7 @@ pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) {
);
}

pub fn test_ipa_with_config(
pub fn test_ipa_with_config<const INPUT_SIZE: usize>(
mode: IpaSecurityModel,
https: bool,
config: IpaQueryConfig,
Expand All @@ -238,7 +242,6 @@ pub fn test_ipa_with_config(
panic!("encrypted_input requires https")
};

const INPUT_SIZE: usize = 100;
// set to true to always keep the temp dir after test finishes
let dir = TempDir::new_delete_on_drop();
let path = dir.path();
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/tests/compact_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn test_compact_gate<I: TryInto<NonZeroU32>>(

// test https with encrypted input
// and http with plaintest input
test_ipa_with_config(mode, encrypted_input, config, encrypted_input);
test_ipa_with_config::<100>(mode, encrypted_input, config, encrypted_input);
}

#[test]
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/tests/helper_networks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,20 @@ fn http_network_large_input() {
#[test]
#[cfg(all(test, web_test))]
fn http_semi_honest_ipa() {
test_ipa(IpaSecurityModel::SemiHonest, false, false);
test_ipa::<100>(IpaSecurityModel::SemiHonest, false, false);
}

#[test]
#[cfg(all(test, web_test))]
fn https_semi_honest_ipa() {
test_ipa(IpaSecurityModel::SemiHonest, true, true);
test_ipa::<100>(IpaSecurityModel::SemiHonest, true, true);
}

#[test]
#[cfg(all(test, web_test))]
#[ignore]
fn https_malicious_ipa() {
test_ipa(IpaSecurityModel::Malicious, true, true);
test_ipa::<100>(IpaSecurityModel::Malicious, true, true);
}

/// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config
Expand Down
12 changes: 9 additions & 3 deletions ipa-core/tests/ipa_with_relaxed_dp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn relaxed_dp_semi_honest() {
let encrypted_input = false;
let config = build_config();

test_ipa_with_config(
test_ipa_with_config::<100>(
IpaSecurityModel::SemiHonest,
encrypted_input,
config,
Expand All @@ -33,7 +33,7 @@ fn relaxed_dp_malicious() {
let encrypted_input = false;
let config = build_config();

test_ipa_with_config(
test_ipa_with_config::<100>(
IpaSecurityModel::Malicious,
encrypted_input,
config,
Expand All @@ -44,5 +44,11 @@ fn relaxed_dp_malicious() {
#[test]
#[cfg(all(test, web_test))]
fn relaxed_dp_https_malicious_ipa() {
test_ipa(IpaSecurityModel::Malicious, true, true);
test_ipa::<100>(IpaSecurityModel::Malicious, true, true);
}

#[test]
#[cfg(all(test, web_test))]
fn relaxed_dp_https_malicious_ipa_10_rows() {
test_ipa::<10>(IpaSecurityModel::Malicious, true, true);
}
Loading