Skip to content

Commit

Permalink
Seed management to improve reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Oct 28, 2024
1 parent 57e2c63 commit 812f094
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 57 deletions.
34 changes: 33 additions & 1 deletion DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,36 @@ While the computation is happening, Management will call the query_status API un

MPC requires thousands of steps to be executed and coordinated across helpers. Each of these calls is represented as a single HTTP call (this becomes more relevant during MPC computation). The service uses HTTP2 multiplexing capabilities, where multiple requests are being sent and received during the same connection.

**Work in progress**
**Work in progress**

# Testing

## Randomness

Random tests can increase coverage by generating unforeseen corner cases. Even with
detailed error messages, the output of a failed random test may not provide enough
information to see what has gone wrong. When this happens, it is important to be
able to re-run the failing case with additional diagnostic output or with candidate
fixes. To make this possible, all random values used in the test should be derived
from a random seed that is logged in the output of a failed test.

Using a random generator provided by `rand::thread_rng` will not typically achieve
reproducibility. Instead, tests should obtain a random number generator by calling
`TestWorld::rng`. An example of such a test is
`breakdown_reveal::tests::semi_honest_happy_path`. To run a test with a particular seed,
pass the seed to `TestWorld::with_seed` (or `TestWorldConfig::with_seed`).

The [`shuttle`](shuttle.md) concurrency test framework provides its own random number
generator and reproducibility mechanisms. The `ipa_core::rand` module automatically
exports either symbols from the `rand` crate or the `shuttle` equivalents based on the
build configuration. If using `TestWorld::rng`, the switch to the `shuttle` RNG is
handled automatically in `TestWorld`. In tests that do not use `TestWorld`, the
`run_random` helper will automatically use the appropriate RNG, and log a seed if using
the standard RNG. An example of such a test is `ordering_sender::test::shuffle_fp31`.
To run a test with a particular seed, use `run_with_seed`.

The `proptest` framework also has its own random number generator and reproducibility
mechanisms, but the `proptest` RNG is not integrated with `TestWorld`. When using
`proptest`, it is recommended to create a random `u64` seed in the proptest-generated
inputs and pass that seed to `TestWorld::with_seed` (or `TestWorldConfig::with_seed`).
An example of such a test is `aggregate_proptest`.
12 changes: 5 additions & 7 deletions ipa-core/src/helpers/buffers/ordering_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,9 @@ mod test {
use crate::{
ff::{Fp31, Fp32BitPrime, Gf9Bit, PrimeField, Serializable, U128Conversions},
helpers::MpcMessage,
rand::thread_rng,
secret_sharing::SharedValue,
sync::Arc,
test_executor::run,
test_executor::{run, run_random},
};

fn sender<F: PrimeField>() -> Arc<OrderingSender> {
Expand Down Expand Up @@ -725,9 +724,9 @@ mod test {
}

/// Shuffle `count` indices.
pub fn shuffle_indices(count: usize) -> Vec<usize> {
pub fn shuffle_indices(count: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut indices = (0..count).collect::<Vec<_>>();
indices.shuffle(&mut thread_rng());
indices.shuffle(rng);
indices
}

Expand All @@ -737,11 +736,10 @@ mod test {
const COUNT: usize = 16;
const SZ: usize = <<Fp31 as Serializable>::Size as Unsigned>::USIZE;

run(|| async {
let mut rng = thread_rng();
run_random(|mut rng| async move {
let mut values = Vec::with_capacity(COUNT);
values.resize_with(COUNT, || rng.gen::<Fp31>());
let indices = shuffle_indices(COUNT);
let indices = shuffle_indices(COUNT, &mut rng);

let sender = sender::<Fp31>();
let (_, (), output) = join3(
Expand Down
43 changes: 41 additions & 2 deletions ipa-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ pub(crate) mod executor {
pub(crate) mod test_executor {
use std::future::Future;

use shuttle::rand::{rngs::ThreadRng, thread_rng};

pub fn run<F, Fut>(f: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Expand All @@ -277,15 +279,31 @@ pub(crate) mod test_executor {
{
shuttle::check_random(move || shuttle::future::block_on(f()), ITER);
}

pub fn run_random<F, Fut>(f: F)
where
F: Fn(ThreadRng) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
run(move || f(thread_rng()));
}
}

#[cfg(all(test, not(feature = "shuttle")))]
pub(crate) mod test_executor {
use std::future::Future;

use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};

// These routines use `FnOnce` because it is easier than dealing with lifetimes of
// `&mut rng` borrows in futures. If there were a need to support multiple
// iterations (or to make the API use `Fn` to match the shuttle version), the
// simplest strategy might be to seed per-iteration RNGs from a primary RNG, like
// `TestWorld::rng`.

pub fn run_with<F, Fut, const ITER: usize>(f: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
F: FnOnce() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
tokio::runtime::Builder::new_multi_thread()
Expand All @@ -301,11 +319,32 @@ pub(crate) mod test_executor {
#[allow(dead_code)]
pub fn run<F, Fut>(f: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
F: FnOnce() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
run_with::<_, _, 1>(f);
}

#[allow(dead_code)]
pub fn run_with_seed<F, Fut>(seed: u64, f: F)
where
F: FnOnce(StdRng) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
println!("Random seed {seed}");
let rng = StdRng::seed_from_u64(seed);
run(move || f(rng));
}

#[allow(dead_code)]
pub fn run_random<F, Fut>(f: F)
where
F: FnOnce(StdRng) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()>,
{
let seed = thread_rng().gen();
run_with_seed(seed, f);
}
}

pub const CRATE_NAME: &str = env!("CARGO_CRATE_NAME");
Expand Down
7 changes: 4 additions & 3 deletions ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ where
#[cfg(all(test, any(unit_test, feature = "shuttle")))]
pub mod tests {
use futures::TryFutureExt;
use rand::{seq::SliceRandom, Rng};
use rand::seq::SliceRandom;

#[cfg(not(feature = "shuttle"))]
use crate::{ff::boolean_array::BA16, test_executor::run};
Expand All @@ -270,6 +270,7 @@ pub mod tests {
oprf_padding::PaddingParameters,
prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs},
},
rand::Rng,
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom,
},
Expand All @@ -293,7 +294,7 @@ pub mod tests {
// (workers there are really slow).
run_with::<_, _, 3>(|| async {
let world = TestWorld::default();
let mut rng = rand::thread_rng();
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..256));
Expand Down Expand Up @@ -346,7 +347,7 @@ pub mod tests {
type HV = BA16;
run(|| async {
let world = TestWorld::default();
let mut rng = rand::thread_rng();
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..512));
Expand Down
51 changes: 25 additions & 26 deletions ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,10 @@ where

#[cfg(all(test, unit_test))]
pub mod tests {
use std::{array, cmp::min, iter::repeat_with};
use std::cmp::min;

use futures::{stream, StreamExt};
use proptest::prelude::*;
use rand::{rngs::StdRng, SeedableRng};

use super::aggregate_values;
use crate::{
Expand Down Expand Up @@ -487,8 +486,7 @@ pub mod tests {
#[derive(Debug)]
struct AggregatePropTestInputs {
inputs: Vec<[u32; PROP_BUCKETS]>,
expected: BitDecomposed<BA8>,
seed: u64,
expected: BitDecomposed<PropHistogramValue>,
len: usize,
tv_bits: usize,
}
Expand All @@ -503,37 +501,37 @@ pub mod tests {
(
len in 0..=max_len,
tv_bits in 0..=PROP_MAX_TV_BITS,
seed in any::<u64>(),
)
(
len in Just(len),
tv_bits in Just(tv_bits),
inputs in prop::collection::vec(prop::array::uniform(0u32..1 << tv_bits), len),
)
-> AggregatePropTestInputs {
let mut rng = StdRng::seed_from_u64(seed);
let mut expected = vec![0; PROP_BUCKETS];
let inputs = repeat_with(|| {
let row: [u32; PROP_BUCKETS] = array::from_fn(|_| rng.gen_range(0..1 << tv_bits));
for row in &inputs {
for (exp, val) in expected.iter_mut().zip(row) {
*exp = min(*exp + val, (1 << PropHistogramValue::BITS) - 1);
}
row
})
.take(len)
.collect();
}

let expected = input_row::<PROP_BUCKETS>(usize::try_from(PropHistogramValue::BITS).unwrap(), &expected)
.map(|x| x.into_iter().collect());

AggregatePropTestInputs {
inputs,
expected,
seed,
len,
tv_bits,
}
}
}

proptest! {
#[test]
fn aggregate_proptest(
input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN)
input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN),
seed in any::<u64>(),
) {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let AggregatePropTestInputs {
Expand All @@ -545,18 +543,19 @@ pub mod tests {
let inputs = inputs.into_iter().map(move |row| {
Ok(input_row(tv_bits, &row))
});
let result : BitDecomposed<BA8> = TestWorld::default().upgraded_semi_honest(inputs, |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, PropHistogramValue, PROP_BUCKETS>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
None,
)
})
.await
.map(Result::unwrap)
.reconstruct_arr();
let result: BitDecomposed<PropHistogramValue> = TestWorld::with_seed(seed)
.upgraded_semi_honest(inputs, |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, PropHistogramValue, PROP_BUCKETS>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
None,
)
})
.await
.map(Result::unwrap)
.reconstruct_arr();

assert_eq!(result, expected);
});
Expand Down
Loading

0 comments on commit 812f094

Please sign in to comment.