Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seed management to improve reproducibility #1380

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,36 @@ While the computation is happening, Management will call the query_status API un

MPC requires thousands of steps to be executed and coordinated across helpers. Each of these calls is represented as a single HTTP call (this becomes more relevant during MPC computation). The service uses HTTP2 multiplexing capabilities, where multiple requests are being sent and received during the same connection.

**Work in progress**
**Work in progress**

# Testing

## Randomness

Random tests can increase coverage by generating unforeseen corner cases. Even with
detailed error messages, the output of a failed random test may not provide enough
information to see what has gone wrong. When this happens, it is important to be
able to re-run the failing case with additional diagnostic output or with candidate
fixes. To make this possible, all random values used in the test should be derived
from a random seed that is logged in the output of a failed test.

Using a random generator provided by `rand::thread_rng` will not typically achieve
reproducibility. Instead, tests should obtain a random number generator by calling
`TestWorld::rng`. An example of such a test is
`breakdown_reveal::tests::semi_honest_happy_path`. To run a test with a particular seed,
pass the seed to `TestWorld::with_seed` (or `TestWorldConfig::with_seed`).

The [`shuttle`](shuttle.md) concurrency test framework provides its own random number
generator and reproducibility mechanisms. The `ipa_core::rand` module automatically
exports either symbols from the `rand` crate or the `shuttle` equivalents based on the
build configuration. If using `TestWorld::rng`, the switch to the `shuttle` RNG is
handled automatically in `TestWorld`. In tests that do not use `TestWorld`, the
`run_random` helper will automatically use the appropriate RNG, and log a seed if using
the standard RNG. An example of such a test is `ordering_sender::test::shuffle_fp31`.
To run a test with a particular seed, use `run_with_seed`.

The `proptest` framework also has its own random number generator and reproducibility
mechanisms, but the `proptest` RNG is not integrated with `TestWorld`. When using
`proptest`, it is recommended to create a random `u64` seed in the proptest-generated
inputs and pass that seed to `TestWorld::with_seed` (or `TestWorldConfig::with_seed`).
An example of such a test is `aggregate_proptest`.
12 changes: 5 additions & 7 deletions ipa-core/src/helpers/buffers/ordering_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,9 @@ mod test {
use crate::{
ff::{Fp31, Fp32BitPrime, Gf9Bit, PrimeField, Serializable, U128Conversions},
helpers::MpcMessage,
rand::thread_rng,
secret_sharing::SharedValue,
sync::Arc,
test_executor::run,
test_executor::{run, run_random},
};

fn sender<F: PrimeField>() -> Arc<OrderingSender> {
Expand Down Expand Up @@ -725,9 +724,9 @@ mod test {
}

/// Shuffle `count` indices.
pub fn shuffle_indices(count: usize) -> Vec<usize> {
pub fn shuffle_indices(count: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut indices = (0..count).collect::<Vec<_>>();
indices.shuffle(&mut thread_rng());
indices.shuffle(rng);
indices
}

Expand All @@ -737,11 +736,10 @@ mod test {
const COUNT: usize = 16;
const SZ: usize = <<Fp31 as Serializable>::Size as Unsigned>::USIZE;

run(|| async {
let mut rng = thread_rng();
run_random(|mut rng| async move {
let mut values = Vec::with_capacity(COUNT);
values.resize_with(COUNT, || rng.gen::<Fp31>());
let indices = shuffle_indices(COUNT);
let indices = shuffle_indices(COUNT, &mut rng);

let sender = sender::<Fp31>();
let (_, (), output) = join3(
Expand Down
42 changes: 40 additions & 2 deletions ipa-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ pub(crate) mod executor {
pub(crate) mod test_executor {
use std::future::Future;

use shuttle::rand::{rngs::ThreadRng, thread_rng};

pub fn run<F, Fut>(f: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Expand All @@ -277,15 +279,30 @@ pub(crate) mod test_executor {
{
shuttle::check_random(move || shuttle::future::block_on(f()), ITER);
}

pub fn run_random<F, Fut>(f: F)
where
F: Fn(ThreadRng) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
run(move || f(thread_rng()));
}
}

#[cfg(all(test, not(feature = "shuttle")))]
pub(crate) mod test_executor {
use std::future::Future;

use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};

// These routines use `FnOnce` because it is easier than dealing with lifetimes of
// `&mut rng` borrows in futures. If there were a need to support multiple
// iterations (or to make the API use `Fn` to match the shuttle version), the
// simplest strategy might be to seed per-iteration RNGs from a primary RNG, like
// `TestWorld::rng`.
pub fn run_with<F, Fut, const ITER: usize>(f: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
F: FnOnce() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
tokio::runtime::Builder::new_multi_thread()
Expand All @@ -301,11 +318,32 @@ pub(crate) mod test_executor {
#[allow(dead_code)]
pub fn run<F, Fut>(f: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
F: FnOnce() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
run_with::<_, _, 1>(f);
}

#[allow(dead_code)]
pub fn run_with_seed<F, Fut>(seed: u64, f: F)
where
F: FnOnce(StdRng) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
println!("Random seed {seed}");
let rng = StdRng::seed_from_u64(seed);
run(move || f(rng));
}

#[allow(dead_code)]
pub fn run_random<F, Fut>(f: F)
where
F: FnOnce(StdRng) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
let seed = thread_rng().gen();
run_with_seed(seed, f);
}
}

pub const CRATE_NAME: &str = env!("CARGO_CRATE_NAME");
Expand Down
7 changes: 4 additions & 3 deletions ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ where
#[cfg(all(test, any(unit_test, feature = "shuttle")))]
pub mod tests {
use futures::TryFutureExt;
use rand::{seq::SliceRandom, Rng};
use rand::seq::SliceRandom;

#[cfg(not(feature = "shuttle"))]
use crate::{ff::boolean_array::BA16, test_executor::run};
Expand All @@ -270,6 +270,7 @@ pub mod tests {
oprf_padding::PaddingParameters,
prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs},
},
rand::Rng,
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom,
},
Expand All @@ -293,7 +294,7 @@ pub mod tests {
// (workers there are really slow).
run_with::<_, _, 3>(|| async {
let world = TestWorld::default();
let mut rng = rand::thread_rng();
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..256));
Expand Down Expand Up @@ -346,7 +347,7 @@ pub mod tests {
type HV = BA16;
run(|| async {
let world = TestWorld::default();
let mut rng = rand::thread_rng();
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..512));
Expand Down
51 changes: 25 additions & 26 deletions ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,10 @@ where

#[cfg(all(test, unit_test))]
pub mod tests {
use std::{array, cmp::min, iter::repeat_with};
use std::cmp::min;

use futures::{stream, StreamExt};
use proptest::prelude::*;
use rand::{rngs::StdRng, SeedableRng};

use super::aggregate_values;
use crate::{
Expand Down Expand Up @@ -487,8 +486,7 @@ pub mod tests {
#[derive(Debug)]
struct AggregatePropTestInputs {
inputs: Vec<[u32; PROP_BUCKETS]>,
expected: BitDecomposed<BA8>,
seed: u64,
expected: BitDecomposed<PropHistogramValue>,
len: usize,
tv_bits: usize,
}
Expand All @@ -503,37 +501,37 @@ pub mod tests {
(
len in 0..=max_len,
tv_bits in 0..=PROP_MAX_TV_BITS,
seed in any::<u64>(),
)
(
len in Just(len),
tv_bits in Just(tv_bits),
inputs in prop::collection::vec(prop::array::uniform(0u32..1 << tv_bits), len),
)
-> AggregatePropTestInputs {
let mut rng = StdRng::seed_from_u64(seed);
let mut expected = vec![0; PROP_BUCKETS];
let inputs = repeat_with(|| {
let row: [u32; PROP_BUCKETS] = array::from_fn(|_| rng.gen_range(0..1 << tv_bits));
for row in &inputs {
for (exp, val) in expected.iter_mut().zip(row) {
*exp = min(*exp + val, (1 << PropHistogramValue::BITS) - 1);
}
row
})
.take(len)
.collect();
}
Comment on lines +504 to +516
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are not essential to this PR, but it seemed clearer not to have two different seeds, and using proptest for input generation enables using its other capabilities like shrinking.


let expected = input_row::<PROP_BUCKETS>(usize::try_from(PropHistogramValue::BITS).unwrap(), &expected)
.map(|x| x.into_iter().collect());

AggregatePropTestInputs {
inputs,
expected,
seed,
len,
tv_bits,
}
}
}

proptest! {
#[test]
fn aggregate_proptest(
input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN)
input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN),
seed in any::<u64>(),
) {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let AggregatePropTestInputs {
Expand All @@ -545,18 +543,19 @@ pub mod tests {
let inputs = inputs.into_iter().map(move |row| {
Ok(input_row(tv_bits, &row))
});
let result : BitDecomposed<BA8> = TestWorld::default().upgraded_semi_honest(inputs, |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, PropHistogramValue, PROP_BUCKETS>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
None,
)
})
.await
.map(Result::unwrap)
.reconstruct_arr();
let result: BitDecomposed<PropHistogramValue> = TestWorld::with_seed(seed)
.upgraded_semi_honest(inputs, |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, PropHistogramValue, PROP_BUCKETS>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
None,
)
})
.await
.map(Result::unwrap)
.reconstruct_arr();

assert_eq!(result, expected);
});
Expand Down
Loading