From b628bf22c732874bb915eb936c0be3c7a2f58874 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 25 Sep 2024 20:41:44 -0700 Subject: [PATCH 1/7] Move some tests --- .../src/protocol/context/dzkp_validator.rs | 141 +++++++++++++++--- ipa-core/src/secret_sharing/vector/impls.rs | 101 ------------- 2 files changed, 123 insertions(+), 119 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 517d4db46..31821c592 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -827,31 +827,136 @@ mod tests { use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; use futures::{StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prop_compose, proptest, sample::select}; - use rand::{thread_rng, Rng}; + use proptest::{prelude::prop, prop_compose, proptest}; + use rand::{distributions::Standard, prelude::Distribution}; use crate::{ - error::Error, - ff::{boolean::Boolean, Fp61BitPrime}, - protocol::{ - basics::SecureMul, + error::Error, ff::{ + boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, Fp61BitPrime + }, protocol::{ + basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, - dzkp_validator::{ + dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, - }, - Context, UpgradableContext, TEST_DZKP_STEPS, + }, Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, UpgradableContext, TEST_DZKP_STEPS }, Gate, RecordId, - }, - secret_sharing::{ - replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, - Vectorizable, - }, - seq_join::{seq_join, SeqJoin}, - test_fixture::{join3v, Reconstruct, Runner, TestWorld}, + }, rand::{thread_rng, Rng}, secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, + IntoShares, SharedValue, Vectorizable, + }, seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld} }; + async fn test_simplest_circuit_semi_honest() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let sh_ctx = v.context(); + + let result = select( + sh_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + + #[tokio::test] + async fn simplest_circuit_semi_honest() { + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + } + + async fn test_simplest_circuit_malicious() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.malicious_contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context(); + + let result = select( + m_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + + #[tokio::test] + async fn simplest_circuit_malicious() { + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + } + #[tokio::test] async fn dzkp_malicious() { const COUNT: usize = 32; @@ -1022,7 +1127,7 @@ mod tests { } prop_compose! { - fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { + fn arb_count_and_chunk()((log_count, log_multiplication_amount) in prop::sample::select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { (1usize<(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); - let m_ctx = v.context(); - - let result = select( - m_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - - #[tokio::test] - async fn simplest_circuit_semi_honest() { - let world = TestWorld::default(); - let context = world.contexts(); - let mut rng = thread_rng(); - - let bit = rng.gen::(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); - let sh_ctx = v.context(); - - let result = select( - sh_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - } } }; } From a4c6f03f8674328d1c879be10d25c89047c35606 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 25 Sep 2024 21:08:12 -0700 Subject: [PATCH 2/7] Make the DZKP batching proptest more comprehensive --- .../src/protocol/context/dzkp_validator.rs | 95 +++++++++++-------- 1 file changed, 56 insertions(+), 39 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 31821c592..87f442c77 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -827,27 +827,39 @@ mod tests { use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; use futures::{StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prelude::prop, prop_compose, proptest}; + use proptest::{prelude::Strategy, prop_oneof, proptest}; use rand::{distributions::Standard, prelude::Distribution}; use crate::{ - error::Error, ff::{ - boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, Fp61BitPrime - }, protocol::{ + error::Error, + ff::{ + boolean::Boolean, + boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, + Fp61BitPrime, + }, + protocol::{ basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, dzkp_validator::{ + dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, + dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, - }, Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, UpgradableContext, TEST_DZKP_STEPS + }, + Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, + UpgradableContext, TEST_DZKP_STEPS, }, Gate, RecordId, - }, rand::{thread_rng, Rng}, secret_sharing::{ - replicated::semi_honest::AdditiveShare as Replicated, - IntoShares, SharedValue, Vectorizable, - }, seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld} + }, + rand::{thread_rng, Rng}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, + Vectorizable, + }, + seq_join::seq_join, + sharding::NotSharded, + test_fixture::{join3v, Reconstruct, Runner, TestWorld}, }; - async fn test_simplest_circuit_semi_honest() + async fn test_select_semi_honest() where V: BooleanArray, for<'a> Replicated: BooleanArrayMul>, @@ -893,16 +905,16 @@ mod tests { } #[tokio::test] - async fn simplest_circuit_semi_honest() { - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; + async fn select_semi_honest() { + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; } - async fn test_simplest_circuit_malicious() + async fn test_select_malicious() where V: BooleanArray, for<'a> Replicated: BooleanArrayMul>, @@ -948,17 +960,17 @@ mod tests { } #[tokio::test] - async fn simplest_circuit_malicious() { - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; + async fn select_malicious() { + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; } #[tokio::test] - async fn dzkp_malicious() { + async fn two_multiplies_malicious() { const COUNT: usize = 32; let mut rng = thread_rng(); @@ -1019,8 +1031,8 @@ mod tests { } /// test for testing `validated_seq_join` - /// similar to `complex_circuit` in `validator.rs` - async fn complex_circuit_dzkp( + /// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) + async fn chained_multiplies_dzkp( count: usize, max_multiplications_per_gate: usize, ) -> Result<(), Error> { @@ -1050,7 +1062,7 @@ mod tests { .map(|(ctx, input_shares)| async move { let v = ctx .set_total_records(count - 1) - .dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get()); + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); let m_ctx = v.context(); let m_results = v @@ -1126,19 +1138,24 @@ mod tests { Ok(()) } - prop_compose! { - fn arb_count_and_chunk()((log_count, log_multiplication_amount) in prop::sample::select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { - (1usize< impl Strategy { + prop_oneof![1usize..=512, (1usize..=9).prop_map(|i| 1usize << i)] + } + + fn max_multiplications_per_gate_strategy() -> impl Strategy { + prop_oneof![1usize..=128, (1usize..=7).prop_map(|i| 1usize << i)] } proptest! { #[test] - fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){ - let future = async { - let _ = complex_circuit_dzkp(count, multiplication_amount).await; - }; - tokio::runtime::Runtime::new().unwrap().block_on(future); + fn test_chained_multiplies_dzkp( + record_count in record_count_strategy(), + max_multiplications_per_gate in max_multiplications_per_gate_strategy(), + ) { + println!("record_count {record_count} batch {max_multiplications_per_gate}"); + tokio::runtime::Runtime::new().unwrap().block_on(async { + chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); + }); } } From 1a6e388449c6fbfe2b94f2166526d15969c89d23 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 25 Sep 2024 09:15:17 -0700 Subject: [PATCH 3/7] Make active_work at least records_per_batch --- ipa-core/src/protocol/context/batcher.rs | 4 ++++ ipa-core/src/protocol/context/dzkp_malicious.rs | 11 ++++++++++- ipa-core/src/protocol/context/malicious.rs | 15 ++++++++++++--- ipa-core/src/protocol/context/validator.rs | 2 +- ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 4 +--- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index f7021d988..2cd00be01 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> { self.total_records = self.total_records.overwrite(total_records.into()); } + pub fn records_per_batch(&self) -> usize { + self.records_per_batch + } + fn batch_offset(&self, record_id: RecordId) -> usize { let batch_index = usize::from(record_id) / self.records_per_batch; batch_index diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 73cfda40c..fd37232b0 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -1,4 +1,5 @@ use std::{ + cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -29,6 +30,7 @@ use crate::{ pub struct DZKPUpgraded<'a> { validator_inner: Weak>, base_ctx: MaliciousContext<'a>, + active_work: NonZeroUsize, } impl<'a> DZKPUpgraded<'a> { @@ -36,9 +38,16 @@ impl<'a> DZKPUpgraded<'a> { validator_inner: &Arc>, base_ctx: MaliciousContext<'a>, ) -> Self { + // Adjust active_work to be at least records_per_batch. If it is less, we will + // stall, since every record in the batch remains incomplete until the batch is + // validated. + let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); + let active_work = + NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); Self { validator_inner: Arc::downgrade(validator_inner), base_ctx, + active_work, } } @@ -130,7 +139,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> { impl<'a> SeqJoin for DZKPUpgraded<'a> { fn active_work(&self) -> NonZeroUsize { - self.base_ctx.active_work() + self.active_work } } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index a93a8edfb..ceee87fdf 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -1,5 +1,6 @@ use std::{ any::type_name, + cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -174,13 +175,21 @@ pub(super) type MacBatcher<'a, F> = Mutex { batch: Weak>, base_ctx: Context<'a>, + active_work: NonZeroUsize, } impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { + pub(super) fn new(batch: &Arc>, base_ctx: Context<'a>) -> Self { + // Adjust active_work to be at least records_per_batch. The MAC validator + // currently configures the batcher with records_per_batch = active_work, which + // makes this adjustment a no-op, but we do it this way to match the DZKP validator. + let records_per_batch = batch.lock().unwrap().records_per_batch(); + let active_work = + NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); Self { batch: Arc::downgrade(batch), - base_ctx: ctx, + base_ctx, + active_work, } } @@ -297,7 +306,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { fn active_work(&self) -> NonZeroUsize { - self.base_ctx.active_work() + self.active_work } } diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 33303fb9b..e57ae3c6a 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { // TODO: Right now we set the batch work to be equal to active_work, // but it does not need to be. We can make this configurable if needed. - let records_per_batch = ctx.active_work().get().min(total_records.get()); + let records_per_batch = ctx.active_work().get(); Self { protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 267631da0..a745d5ea7 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -510,9 +510,7 @@ where protocol: &Step::Attribute, validate: &Step::AttributeValidate, }, - // The size of a single batch should not exceed the active work limit, - // otherwise it will stall - std::cmp::min(sh_ctx.active_work().get(), chunk_size), + chunk_size, ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; From 706dcbea218d39e35e64fd7273561316bdf50546 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 26 Sep 2024 11:22:45 -0700 Subject: [PATCH 4/7] Improvements to batching proptest --- .../src/protocol/context/dzkp_validator.rs | 112 ++++++++++++++++-- 1 file changed, 100 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 87f442c77..2ff635d02 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -825,9 +825,13 @@ mod tests { }; use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; - use futures::{StreamExt, TryStreamExt}; + use futures::{stream, StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prelude::Strategy, prop_oneof, proptest}; + use proptest::{ + prelude::{Just, Strategy}, + prop_compose, prop_oneof, proptest, + test_runner::Config as ProptestConfig, + }; use rand::{distributions::Standard, prelude::Distribution}; use crate::{ @@ -1030,6 +1034,51 @@ mod tests { } } + /// Similar to `test_select_malicious`, but operating on vectors + async fn multi_select_malicious(count: usize, max_multiplications_per_gate: usize) + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let mut rng = thread_rng(); + + let bit: Vec = repeat_with(|| rng.gen::()).take(count).collect(); + let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); + let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); + + let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); + let m_ctx = v.context(); + + v.validated_seq_join(stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + )) + .try_collect() + .await + }, + ) + .await + .map(Result::unwrap); + + let ab: Vec = [ab0, ab1, ab2].reconstruct(); + + for i in 0..count { + assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); + } + } + /// test for testing `validated_seq_join` /// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) async fn chained_multiplies_dzkp( @@ -1139,23 +1188,62 @@ mod tests { } fn record_count_strategy() -> impl Strategy { - prop_oneof![1usize..=512, (1usize..=9).prop_map(|i| 1usize << i)] + // The chained_multiplies test has count - 1 records, so 1 is not a valid input size. + // It is for multi_select though. + prop_oneof![2usize..=512, (1u32..=9).prop_map(|i| 1usize << i)] + } + + fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy { + let max_max_mults = record_count.min(128); + prop_oneof![ + 1usize..=max_max_mults, + (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) + ] } - fn max_multiplications_per_gate_strategy() -> impl Strategy { - prop_oneof![1usize..=128, (1usize..=7).prop_map(|i| 1usize << i)] + prop_compose! { + fn batching() + (record_count in record_count_strategy()) + (record_count in Just(record_count), max_mults in max_multiplications_per_gate_strategy(record_count)) + -> (usize, usize) + { + (record_count, max_mults) + } } proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] #[test] - fn test_chained_multiplies_dzkp( - record_count in record_count_strategy(), - max_multiplications_per_gate in max_multiplications_per_gate_strategy(), - ) { + fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { println!("record_count {record_count} batch {max_multiplications_per_gate}"); - tokio::runtime::Runtime::new().unwrap().block_on(async { - chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); - }); + if record_count / max_multiplications_per_gate >= 192 { + // TODO: #1269, or even if we don't fix that, don't hardcode the limit. + println!("skipping config because batch count exceeds limit of 192"); + } + // This condition is correct only for active_work = 16 and record size of 1 byte. + else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { + // TODO: #1300, read_size | batch_size. + // Note: for active work < 2048, read size matches active work. + + // Besides read_size | batch_size, there is also a constraint + // something like active_work > read_size + batch_size - 1. + println!("skipping config due to read_size vs. batch_size constraints"); + } else { + tokio::runtime::Runtime::new().unwrap().block_on(async { + chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + }); + } } } From c5c0ccf8a35d5dbbb04a9161228eb5aef0d70bd8 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 26 Sep 2024 11:22:35 -0700 Subject: [PATCH 5/7] Revise based on PR feedback and offline discussion --- .../src/protocol/context/dzkp_malicious.rs | 25 ++++++++++++++----- .../src/protocol/context/dzkp_validator.rs | 4 ++- ipa-core/src/protocol/context/malicious.rs | 23 +++++++++-------- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index fd37232b0..70dd3d2af 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -1,5 +1,4 @@ use std::{ - cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -38,12 +37,26 @@ impl<'a> DZKPUpgraded<'a> { validator_inner: &Arc>, base_ctx: MaliciousContext<'a>, ) -> Self { - // Adjust active_work to be at least records_per_batch. If it is less, we will - // stall, since every record in the batch remains incomplete until the batch is - // validated. let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); - let active_work = - NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); + let active_work = if records_per_batch == 1 { + // If records_per_batch is 1, let active_work be anything. This only happens + // in tests; there shouldn't be a risk of deadlocks with one record per + // batch; and UnorderedReceiver capacity (which is set from active_work) + // must be at least two. + base_ctx.active_work() + } else { + // Adjust active_work to match records_per_batch. If it is less, we will + // certainly stall, since every record in the batch remains incomplete until + // the batch is validated. It is possible that it can be larger, but making + // it the same seems safer for now. + let active_work = NonZeroUsize::new(records_per_batch).unwrap(); + tracing::debug!( + "Changed active_work from {} to {} to match batch size", + base_ctx.active_work().get(), + active_work, + ); + active_work + }; Self { validator_inner: Arc::downgrade(validator_inner), base_ctx, diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 2ff635d02..f5586336f 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -838,7 +838,7 @@ mod tests { error::Error, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, + boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, Fp61BitPrime, }, protocol::{ @@ -910,6 +910,7 @@ mod tests { #[tokio::test] async fn select_semi_honest() { + test_select_semi_honest::().await; test_select_semi_honest::().await; test_select_semi_honest::().await; test_select_semi_honest::().await; @@ -965,6 +966,7 @@ mod tests { #[tokio::test] async fn select_malicious() { + test_select_malicious::().await; test_select_malicious::().await; test_select_malicious::().await; test_select_malicious::().await; diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index ceee87fdf..18b9b8e29 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -1,6 +1,5 @@ use std::{ any::type_name, - cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -175,21 +174,23 @@ pub(super) type MacBatcher<'a, F> = Mutex { batch: Weak>, base_ctx: Context<'a>, - active_work: NonZeroUsize, } impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new(batch: &Arc>, base_ctx: Context<'a>) -> Self { - // Adjust active_work to be at least records_per_batch. The MAC validator - // currently configures the batcher with records_per_batch = active_work, which - // makes this adjustment a no-op, but we do it this way to match the DZKP validator. + pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { + // The DZKP malicious context adjusts active_work to match records_per_batch. + // The MAC validator currently configures the batcher with records_per_batch = + // active_work. If the latter behavior changes, this code may need to be + // updated. let records_per_batch = batch.lock().unwrap().records_per_batch(); - let active_work = - NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); + let active_work = ctx.active_work().get(); + assert_eq!( + records_per_batch, active_work, + "Expect MAC validation batch size ({records_per_batch}) to match active work ({active_work})", + ); Self { batch: Arc::downgrade(batch), - base_ctx, - active_work, + base_ctx: ctx, } } @@ -306,7 +307,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { fn active_work(&self) -> NonZeroUsize { - self.active_work + self.base_ctx.active_work() } } From 966160fcdd12effd9687a159768e25cc6df179a6 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 26 Sep 2024 21:20:45 -0700 Subject: [PATCH 6/7] Fix a bug in the batcher --- ipa-core/src/protocol/context/batcher.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index 2cd00be01..cdbfddbce 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -114,7 +114,7 @@ impl<'a, B> Batcher<'a, B> { while self.batches.len() <= batch_offset { let (validation_result, _) = watch::channel::(false); let state = BatchState { - batch: (self.batch_constructor)(self.first_batch + batch_offset), + batch: (self.batch_constructor)(self.first_batch + self.batches.len()), validation_result, pending_count: 0, pending_records: bitvec![0; self.records_per_batch], @@ -296,6 +296,23 @@ mod tests { ); } + #[test] + fn makes_batches_out_of_order() { + // Regression test for a bug where, when adding batches i..j to fill in a gap in + // the batch deque prior to out-of-order requested batch j, the batcher passed + // batch index `j` to the constructor for all of them, as opposed to the correct + // sequence of indices i..=j. + + let batcher = Batcher::new(1, 2, Box::new(std::convert::identity)); + let mut batcher = batcher.lock().unwrap(); + + batcher.get_batch(RecordId::from(1)); + batcher.get_batch(RecordId::from(0)); + + assert_eq!(batcher.get_batch(RecordId::from(0)).batch, 0); + assert_eq!(batcher.get_batch(RecordId::from(1)).batch, 1); + } + #[tokio::test] async fn validates_batches() { let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new())); From d2512f1a2cfa24c2af58d7700f534586b3cc89ef Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 27 Sep 2024 09:43:32 -0700 Subject: [PATCH 7/7] Restore the batch size kludge for attribution --- ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index a745d5ea7..9a1f8f278 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -510,7 +510,9 @@ where protocol: &Step::Attribute, validate: &Step::AttributeValidate, }, - chunk_size, + // TODO: this should not be necessary, but probably can't be removed + // until we align read_size with the batch size. + std::cmp::min(sh_ctx.active_work().get(), chunk_size), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;