From 812f09463ffcf60458a67d69a4b5a47e2fee4edf Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 25 Oct 2024 18:32:18 -0700 Subject: [PATCH] Seed management to improve reproducibility Fixes #1321 --- DEVELOPMENT.md | 34 +++- .../src/helpers/buffers/ordering_sender.rs | 12 +- ipa-core/src/lib.rs | 43 ++++- .../ipa_prf/aggregation/breakdown_reveal.rs | 7 +- .../src/protocol/ipa_prf/aggregation/mod.rs | 51 +++-- ipa-core/src/test_fixture/world.rs | 177 ++++++++++++++++-- 6 files changed, 267 insertions(+), 57 deletions(-) diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index a6ad2f50c..f0ed42064 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -41,4 +41,36 @@ While the computation is happening, Management will call the query_status API un MPC requires thousands of steps to be executed and coordinated across helpers. Each of these calls is represented as a single HTTP call (this becomes more relevant during MPC computation). The service uses HTTP2 multiplexing capabilities, where multiple requests are being sent and received during the same connection. -**Work in progress** \ No newline at end of file +**Work in progress** + +# Testing + +## Randomness + +Random tests can increase coverage by generating unforeseen corner cases. Even with +detailed error messages, the output of a failed random test may not provide enough +information to see what has gone wrong. When this happens, it is important to be +able to re-run the failing case with additional diagnostic output or with candidate +fixes. To make this possible, all random values used in the test should be derived +from a random seed that is logged in the output of a failed test. + +Using a random generator provided by `rand::thread_rng` will not typically achieve +reproducibility. Instead, tests should obtain a random number generator by calling +`TestWorld::rng`. An example of such a test is +`breakdown_reveal::tests::semi_honest_happy_path`. To run a test with a particular seed, +pass the seed to `TestWorld::with_seed` (or `TestWorldConfig::with_seed`). + +The [`shuttle`](shuttle.md) concurrency test framework provides its own random number +generator and reproducibility mechanisms. The `ipa_core::rand` module automatically +exports either symbols from the `rand` crate or the `shuttle` equivalents based on the +build configuration. If using `TestWorld::rng`, the switch to the `shuttle` RNG is +handled automatically in `TestWorld`. In tests that do not use `TestWorld`, the +`run_random` helper will automatically use the appropriate RNG, and log a seed if using +the standard RNG. An example of such a test is `ordering_sender::test::shuffle_fp31`. +To run a test with a particular seed, use `run_with_seed`. + +The `proptest` framework also has its own random number generator and reproducibility +mechanisms, but the `proptest` RNG is not integrated with `TestWorld`. When using +`proptest`, it is recommended to create a random `u64` seed in the proptest-generated +inputs and pass that seed to `TestWorld::with_seed` (or `TestWorldConfig::with_seed`). +An example of such a test is `aggregate_proptest`. diff --git a/ipa-core/src/helpers/buffers/ordering_sender.rs b/ipa-core/src/helpers/buffers/ordering_sender.rs index 7c21bbde6..256a76d1d 100644 --- a/ipa-core/src/helpers/buffers/ordering_sender.rs +++ b/ipa-core/src/helpers/buffers/ordering_sender.rs @@ -531,10 +531,9 @@ mod test { use crate::{ ff::{Fp31, Fp32BitPrime, Gf9Bit, PrimeField, Serializable, U128Conversions}, helpers::MpcMessage, - rand::thread_rng, secret_sharing::SharedValue, sync::Arc, - test_executor::run, + test_executor::{run, run_random}, }; fn sender() -> Arc { @@ -725,9 +724,9 @@ mod test { } /// Shuffle `count` indices. - pub fn shuffle_indices(count: usize) -> Vec { + pub fn shuffle_indices(count: usize, rng: &mut impl Rng) -> Vec { let mut indices = (0..count).collect::>(); - indices.shuffle(&mut thread_rng()); + indices.shuffle(rng); indices } @@ -737,11 +736,10 @@ mod test { const COUNT: usize = 16; const SZ: usize = <::Size as Unsigned>::USIZE; - run(|| async { - let mut rng = thread_rng(); + run_random(|mut rng| async move { let mut values = Vec::with_capacity(COUNT); values.resize_with(COUNT, || rng.gen::()); - let indices = shuffle_indices(COUNT); + let indices = shuffle_indices(COUNT, &mut rng); let sender = sender::(); let (_, (), output) = join3( diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 345bbe0ae..f6467b87c 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -262,6 +262,8 @@ pub(crate) mod executor { pub(crate) mod test_executor { use std::future::Future; + use shuttle::rand::{rngs::ThreadRng, thread_rng}; + pub fn run(f: F) where F: Fn() -> Fut + Send + Sync + 'static, @@ -277,15 +279,31 @@ pub(crate) mod test_executor { { shuttle::check_random(move || shuttle::future::block_on(f()), ITER); } + + pub fn run_random(f: F) + where + F: Fn(ThreadRng) -> Fut + Send + Sync + 'static, + Fut: Future, + { + run(move || f(thread_rng())); + } } #[cfg(all(test, not(feature = "shuttle")))] pub(crate) mod test_executor { use std::future::Future; + use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; + + // These routines use `FnOnce` because it is easier than dealing with lifetimes of + // `&mut rng` borrows in futures. If there were a need to support multiple + // iterations (or to make the API use `Fn` to match the shuttle version), the + // simplest strategy might be to seed per-iteration RNGs from a primary RNG, like + // `TestWorld::rng`. + pub fn run_with(f: F) where - F: Fn() -> Fut + Send + Sync + 'static, + F: FnOnce() -> Fut + Send + Sync + 'static, Fut: Future, { tokio::runtime::Builder::new_multi_thread() @@ -301,11 +319,32 @@ pub(crate) mod test_executor { #[allow(dead_code)] pub fn run(f: F) where - F: Fn() -> Fut + Send + Sync + 'static, + F: FnOnce() -> Fut + Send + Sync + 'static, Fut: Future, { run_with::<_, _, 1>(f); } + + #[allow(dead_code)] + pub fn run_with_seed(seed: u64, f: F) + where + F: FnOnce(StdRng) -> Fut + Send + Sync + 'static, + Fut: Future, + { + println!("Random seed {seed}"); + let rng = StdRng::seed_from_u64(seed); + run(move || f(rng)); + } + + #[allow(dead_code)] + pub fn run_random(f: F) + where + F: FnOnce(StdRng) -> Fut + Send + Sync + 'static, + Fut: Future, + { + let seed = thread_rng().gen(); + run_with_seed(seed, f); + } } pub const CRATE_NAME: &str = env!("CARGO_CRATE_NAME"); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index b0b17396a..d3ec2ca3e 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -255,7 +255,7 @@ where #[cfg(all(test, any(unit_test, feature = "shuttle")))] pub mod tests { use futures::TryFutureExt; - use rand::{seq::SliceRandom, Rng}; + use rand::seq::SliceRandom; #[cfg(not(feature = "shuttle"))] use crate::{ff::boolean_array::BA16, test_executor::run}; @@ -270,6 +270,7 @@ pub mod tests { oprf_padding::PaddingParameters, prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs}, }, + rand::Rng, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, }, @@ -293,7 +294,7 @@ pub mod tests { // (workers there are really slow). run_with::<_, _, 3>(|| async { let world = TestWorld::default(); - let mut rng = rand::thread_rng(); + let mut rng = world.rng(); let mut expectation = Vec::new(); for _ in 0..32 { expectation.push(rng.gen_range(0u128..256)); @@ -346,7 +347,7 @@ pub mod tests { type HV = BA16; run(|| async { let world = TestWorld::default(); - let mut rng = rand::thread_rng(); + let mut rng = world.rng(); let mut expectation = Vec::new(); for _ in 0..32 { expectation.push(rng.gen_range(0u128..512)); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index a5ca281e6..0f05b4b4d 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -237,11 +237,10 @@ where #[cfg(all(test, unit_test))] pub mod tests { - use std::{array, cmp::min, iter::repeat_with}; + use std::cmp::min; use futures::{stream, StreamExt}; use proptest::prelude::*; - use rand::{rngs::StdRng, SeedableRng}; use super::aggregate_values; use crate::{ @@ -487,8 +486,7 @@ pub mod tests { #[derive(Debug)] struct AggregatePropTestInputs { inputs: Vec<[u32; PROP_BUCKETS]>, - expected: BitDecomposed, - seed: u64, + expected: BitDecomposed, len: usize, tv_bits: usize, } @@ -503,20 +501,19 @@ pub mod tests { ( len in 0..=max_len, tv_bits in 0..=PROP_MAX_TV_BITS, - seed in any::(), + ) + ( + len in Just(len), + tv_bits in Just(tv_bits), + inputs in prop::collection::vec(prop::array::uniform(0u32..1 << tv_bits), len), ) -> AggregatePropTestInputs { - let mut rng = StdRng::seed_from_u64(seed); let mut expected = vec![0; PROP_BUCKETS]; - let inputs = repeat_with(|| { - let row: [u32; PROP_BUCKETS] = array::from_fn(|_| rng.gen_range(0..1 << tv_bits)); + for row in &inputs { for (exp, val) in expected.iter_mut().zip(row) { *exp = min(*exp + val, (1 << PropHistogramValue::BITS) - 1); } - row - }) - .take(len) - .collect(); + } let expected = input_row::(usize::try_from(PropHistogramValue::BITS).unwrap(), &expected) .map(|x| x.into_iter().collect()); @@ -524,16 +521,17 @@ pub mod tests { AggregatePropTestInputs { inputs, expected, - seed, len, tv_bits, } } } + proptest! { #[test] fn aggregate_proptest( - input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN) + input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN), + seed in any::(), ) { tokio::runtime::Runtime::new().unwrap().block_on(async { let AggregatePropTestInputs { @@ -545,18 +543,19 @@ pub mod tests { let inputs = inputs.into_iter().map(move |row| { Ok(input_row(tv_bits, &row)) }); - let result : BitDecomposed = TestWorld::default().upgraded_semi_honest(inputs, |ctx, inputs| { - let num_rows = inputs.len(); - aggregate_values::<_, PropHistogramValue, PROP_BUCKETS>( - ctx, - stream::iter(inputs).boxed(), - num_rows, - None, - ) - }) - .await - .map(Result::unwrap) - .reconstruct_arr(); + let result: BitDecomposed = TestWorld::with_seed(seed) + .upgraded_semi_honest(inputs, |ctx, inputs| { + let num_rows = inputs.len(); + aggregate_values::<_, PropHistogramValue, PROP_BUCKETS>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + None, + ) + }) + .await + .map(Result::unwrap) + .reconstruct_arr(); assert_eq!(result, expected); }); diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 5c23cb2d8..8a3ff3a95 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -1,7 +1,13 @@ // We have quite a bit of code that is only used when descriptive-gate is enabled. #![allow(dead_code)] use std::{ - array::from_fn, borrow::Borrow, fmt::Debug, io::stdout, iter, iter::zip, marker::PhantomData, + array::from_fn, + borrow::Borrow, + fmt::Debug, + io::stdout, + iter::{self, zip}, + marker::PhantomData, + sync::Mutex, }; use async_trait::async_trait; @@ -9,7 +15,7 @@ use futures::{future::join_all, stream::FuturesOrdered, Future, StreamExt}; use rand::{ distributions::{Distribution, Standard}, rngs::StdRng, - thread_rng, Rng, RngCore, SeedableRng, + Rng, RngCore, SeedableRng, }; use tracing::{Instrument, Level, Span}; @@ -29,6 +35,7 @@ use crate::{ prss::Endpoint as PrssEndpoint, Gate, QueryId, RecordId, }, + rand::thread_rng, secret_sharing::{ replicated::malicious::{ DowngradeMalicious, ExtendableField, ThisCodeIsAuthorizedToDowngradeFromMalicious, @@ -66,7 +73,7 @@ pub trait ShardingScheme { /// Helper trait to parametrize [`Runner`] trait based on the sharding scheme chosen. The whole /// purpose of it is to be able to say for sharded runs, the input must be in a form of a [`Vec`] pub trait RunnerInput: Send { - fn share(self) -> [S::Container; 3]; + fn share_with(self, rng: &mut R) -> [S::Container; 3]; } /// Trait that defines how helper inputs are distributed across shards. The simplest implementation @@ -93,6 +100,9 @@ pub struct WithShards { pub struct TestWorld { shards: Box<[ShardWorld]>, metrics_handle: MetricsHandle, + // Using a mutex here is not as unfortunate as it might initially appear, because we + // only use this RNG as a source of seeds for other RNGs. See `fn rng`. + rng: Mutex, gate_vendor: Box, _shard_network: InMemoryShardNetwork, _phantom: PhantomData, @@ -190,6 +200,13 @@ impl Default for TestWorld { } } +impl TestWorld { + #[must_use] + pub fn with_seed(seed: u64) -> Self { + Self::new_with(TestWorldConfig::default().with_seed(seed)) + } +} + impl TestWorld> { /// For backward compatibility, this method must have a different name than [`non_sharded`] method. /// @@ -301,8 +318,10 @@ impl TestWorld { #[must_use] pub fn with_config(config: &TestWorldConfig) -> Self { logging::setup(); + // Print to stdout so that it appears in test runs only on failure. println!("TestWorld random seed {seed}", seed = config.seed); + let mut rng = StdRng::seed_from_u64(config.seed); let shard_count = ShardIndex::try_from(S::SHARDS).unwrap(); let shard_network = @@ -314,7 +333,7 @@ impl TestWorld { ShardWorld::new( S::bind_shard(shard), config, - u64::from(shard), + &mut rng, shard_network.shard_transports(shard), ) }) @@ -324,6 +343,7 @@ impl TestWorld { Self { shards, metrics_handle: MetricsHandle::new(config.metrics_level), + rng: Mutex::new(rng), gate_vendor: gate_vendor(config.initial_gate.clone()), _shard_network: shard_network, _phantom: PhantomData, @@ -339,6 +359,17 @@ impl TestWorld { fn next_gate(&self) -> Gate { self.gate_vendor.next() } + + /// Return a new `Rng` seeded from the primary `TestWorld` RNG. + /// + /// ## Panics + /// If the mutex is poisoned. + #[must_use] + pub fn rng(&self) -> impl Rng { + // We need to use the `TestWorld` RNG within the `Runner` helpers, which + // unfortunately take `&self`. + StdRng::from_seed(self.rng.lock().unwrap().gen()) + } } impl Default for TestWorldConfig { @@ -385,8 +416,8 @@ impl TestWorldConfig { } impl + Send, A: Send> RunnerInput for I { - fn share(self) -> [A; 3] { - I::share(self) + fn share_with(self, rng: &mut R) -> [A; 3] { + I::share_with(self, rng) } } @@ -396,8 +427,8 @@ where A: Send, D: Distribute, { - fn share(self) -> [Vec; 3] { - I::share(self) + fn share_with(self, rng: &mut R) -> [Vec; 3] { + I::share_with(self, rng) } } @@ -499,7 +530,8 @@ impl Runner> R: Future + Send, { let shards = self.shards(); - let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share().map(D::distribute); + let mut rng = self.rng(); + let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share_with(&mut rng).map(D::distribute); let gate = self.next_gate(); // No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy` @@ -549,7 +581,8 @@ impl Runner> R: Future + Send, { let shards = self.shards(); - let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share().map(D::distribute); + let mut rng = self.rng(); + let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share_with(&mut rng).map(D::distribute); let gate = self.next_gate(); // todo!() @@ -619,7 +652,7 @@ impl Runner for TestWorld { ShardWorld::::run_either( self.contexts(), self.metrics_handle.span(), - input.share(), + input.share_with(&mut self.rng()), helper_fn, ) .await @@ -658,7 +691,7 @@ impl Runner for TestWorld { ShardWorld::::run_either( self.malicious_contexts(), self.metrics_handle.span(), - input.share(), + input.share_with(&mut self.rng()), helper_fn, ) .await @@ -766,10 +799,10 @@ impl ShardWorld { pub fn new( shard_info: B, config: &TestWorldConfig, - shard_seed: u64, + rng: &mut StdRng, transports: [InMemoryTransport; 3], ) -> Self { - let participants = make_participants(&mut StdRng::seed_from_u64(config.seed + shard_seed)); + let participants = make_participants(rng); let network = InMemoryMpcNetwork::with_stream_interceptor( InMemoryMpcNetwork::noop_handlers(), &config.stream_interceptor, @@ -897,26 +930,33 @@ mod tests { use futures_util::future::try_join4; use crate::{ - ff::{boolean::Boolean, boolean_array::BA3, Field, Fp31, U128Conversions}, + ff::{ + boolean::Boolean, + boolean_array::{BA3, BA8}, + ArrayAccess, Field, Fp31, U128Conversions, + }, helpers::{ in_memory_config::{MaliciousHelper, MaliciousHelperContext}, Direction, Role, TotalRecords, }, protocol::{ basics::SecureMul, + boolean::step::EightBitStep, context::{ - dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, UpgradableContext, - UpgradedContext, Validator, TEST_DZKP_STEPS, + dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, DZKPContext, + UpgradableContext, UpgradedContext, Validator, TEST_DZKP_STEPS, }, + ipa_prf::boolean_ops::addition_sequential::integer_add, prss::SharedRandomness, RecordId, }, + rand::Rng, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, }, sharding::ShardConfiguration, - test_executor::run, + test_executor::{run, run_random}, test_fixture::{world::WithShards, Reconstruct, Runner, TestWorld, TestWorldConfig}, }; @@ -1101,4 +1141,105 @@ mod tests { assert_eq!(input, r); }); } + + #[test] + fn mac_reproducible_from_seed() { + run_random(|mut rng| async move { + async fn test(seed: u64) -> Vec { + let u_and_w = Arc::new(Mutex::new(vec![])); + let u_and_w_ref = Arc::clone(&u_and_w); + + let mut config = TestWorldConfig::default(); + config.seed = seed; + config.stream_interceptor = MaliciousHelper::new( + Role::H1, + config.role_assignment(), + move |ctx: &MaliciousHelperContext, data: &mut Vec| { + if ctx.gate.as_ref().contains("propagate_u_and_w") { + u_and_w_ref.lock().unwrap().extend(data.iter()); + } + }, + ); + + let world = TestWorld::with_config(&config); + let mut rng = world.rng(); + let input: (Fp31, Fp31) = (rng.gen(), rng.gen()); + world + .malicious(input, |ctx, input| async move { + let validator = ctx.set_total_records(1).validator(); + let ctx = validator.context(); + let (a_upgraded, b_upgraded) = input + .clone() + .upgrade(ctx.clone(), RecordId::FIRST) + .await + .unwrap(); + a_upgraded + .multiply(&b_upgraded, ctx.narrow("multiply"), RecordId::FIRST) + .await + .unwrap(); + ctx.validate_record(RecordId::FIRST).await.unwrap(); + }) + .await; + + let result = u_and_w.lock().unwrap().clone(); + result + } + + let seed = rng.gen(); + let first_result = test(seed).await; + let second_result = test(seed).await; + + assert_eq!(first_result, second_result); + }); + } + + #[test] + fn zkp_reproducible_from_seed() { + run_random(|mut rng| async move { + async fn test(seed: u64) -> Vec { + let proof_diff = Arc::new(Mutex::new(vec![])); + let proof_diff_ref = Arc::clone(&proof_diff); + + let mut config = TestWorldConfig::default(); + config.seed = seed; + config.stream_interceptor = MaliciousHelper::new( + Role::H1, + config.role_assignment(), + move |ctx: &MaliciousHelperContext, data: &mut Vec| { + if ctx.gate.as_ref().contains("verify_proof/diff") { + proof_diff_ref.lock().unwrap().extend(data.iter()); + } + }, + ); + + let world = TestWorld::with_config(&config); + let mut rng = world.rng(); + let input: (BA8, BA8) = (rng.gen(), rng.gen()); + world + .malicious(input, |ctx, input| async move { + let validator = ctx.set_total_records(1).dzkp_validator(TEST_DZKP_STEPS, 8); + let ctx = validator.context(); + integer_add::<_, EightBitStep, 1>( + ctx.clone(), + RecordId::FIRST, + &input.0.to_bits(), + &input.1.to_bits(), + ) + .await + .unwrap(); + ctx.validate_record(RecordId::FIRST).await.unwrap(); + }) + .await; + + let result = proof_diff.lock().unwrap().clone(); + result + } + + let seed = rng.gen(); + let first_result = test(seed).await; + let second_result = test(seed).await; + + assert_eq!(first_result, second_result); + }); + } }