Skip to content

Commit

Permalink
Merge pull request #1316 from andyleiserson/active-work
Browse files Browse the repository at this point in the history
Make active_work match records_per_batch
  • Loading branch information
andyleiserson authored Sep 27, 2024
2 parents e3b0243 + d2512f1 commit 59ca3e7
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 124 deletions.
23 changes: 22 additions & 1 deletion ipa-core/src/protocol/context/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> {
self.total_records = self.total_records.overwrite(total_records.into());
}

/// Returns the configured number of records per batch.
pub fn records_per_batch(&self) -> usize {
    self.records_per_batch
}

fn batch_offset(&self, record_id: RecordId) -> usize {
let batch_index = usize::from(record_id) / self.records_per_batch;
batch_index
Expand All @@ -110,7 +114,7 @@ impl<'a, B> Batcher<'a, B> {
while self.batches.len() <= batch_offset {
let (validation_result, _) = watch::channel::<bool>(false);
let state = BatchState {
batch: (self.batch_constructor)(self.first_batch + batch_offset),
batch: (self.batch_constructor)(self.first_batch + self.batches.len()),
validation_result,
pending_count: 0,
pending_records: bitvec![0; self.records_per_batch],
Expand Down Expand Up @@ -292,6 +296,23 @@ mod tests {
);
}

#[test]
fn makes_batches_out_of_order() {
    // Regression test: when filling the gap of batches i..j ahead of an
    // out-of-order request for batch j, the batcher used to hand index `j`
    // to the constructor for every gap batch instead of the proper sequence
    // i..=j. With the identity constructor, each batch stores its own index,
    // so a wrong constructor argument is directly observable.
    let shared = Batcher::new(1, 2, Box::new(std::convert::identity));
    let mut batcher = shared.lock().unwrap();

    // Request batch 1 first, which forces batch 0 to be created as gap fill.
    batcher.get_batch(RecordId::from(1));
    batcher.get_batch(RecordId::from(0));

    assert_eq!(batcher.get_batch(RecordId::from(0)).batch, 0);
    assert_eq!(batcher.get_batch(RecordId::from(1)).batch, 1);
}

#[tokio::test]
async fn validates_batches() {
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));
Expand Down
24 changes: 23 additions & 1 deletion ipa-core/src/protocol/context/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,38 @@ use crate::{
/// DZKP-upgraded malicious context.
pub struct DZKPUpgraded<'a> {
    // Weak back-reference to the owning validator, so this context does not
    // keep the validator (and its batcher) alive.
    validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
    base_ctx: MaliciousContext<'a>,
    // Active-work value reported via `SeqJoin`; chosen at construction to
    // match the batcher's records_per_batch rather than the base context's
    // setting.
    active_work: NonZeroUsize,
}

impl<'a> DZKPUpgraded<'a> {
/// Builds an upgraded context, pinning `active_work` to the batcher's
/// records_per_batch.
pub(super) fn new(
    validator_inner: &Arc<MaliciousDZKPValidatorInner<'a>>,
    base_ctx: MaliciousContext<'a>,
) -> Self {
    let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch();
    let active_work = match records_per_batch {
        // A one-record batch only occurs in tests, and any active_work is
        // acceptable there: a single record per batch cannot deadlock, and the
        // UnorderedReceiver capacity (which is set from active_work) must be
        // at least two.
        1 => base_ctx.active_work(),
        n => {
            // Force active_work to equal records_per_batch. A smaller value is
            // guaranteed to stall, since every record in a batch stays
            // incomplete until the whole batch validates. A larger value might
            // be workable, but keeping them equal is the safer choice for now.
            let adjusted = NonZeroUsize::new(n).unwrap();
            tracing::debug!(
                "Changed active_work from {} to {} to match batch size",
                base_ctx.active_work().get(),
                adjusted,
            );
            adjusted
        }
    };
    Self {
        validator_inner: Arc::downgrade(validator_inner),
        base_ctx,
        active_work,
    }
}

Expand Down Expand Up @@ -130,7 +152,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> {

impl<'a> SeqJoin for DZKPUpgraded<'a> {
    /// Returns the active-work value chosen at construction (matched to the
    /// batcher's records_per_batch) instead of deferring to the base context.
    fn active_work(&self) -> NonZeroUsize {
        self.active_work
    }
}

Expand Down
248 changes: 230 additions & 18 deletions ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,35 +825,158 @@ mod tests {
};

use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec};
use futures::{StreamExt, TryStreamExt};
use futures::{stream, StreamExt, TryStreamExt};
use futures_util::stream::iter;
use proptest::{prop_compose, proptest, sample::select};
use rand::{thread_rng, Rng};
use proptest::{
prelude::{Just, Strategy},
prop_compose, prop_oneof, proptest,
test_runner::Config as ProptestConfig,
};
use rand::{distributions::Standard, prelude::Distribution};

use crate::{
error::Error,
ff::{boolean::Boolean, Fp61BitPrime},
ff::{
boolean::Boolean,
boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8},
Fp61BitPrime,
},
protocol::{
basics::SecureMul,
basics::{select, BooleanArrayMul, SecureMul},
context::{
dzkp_field::{DZKPCompatibleField, BLOCK_SIZE},
dzkp_validator::{
Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE,
},
Context, UpgradableContext, TEST_DZKP_STEPS,
Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
UpgradableContext, TEST_DZKP_STEPS,
},
Gate, RecordId,
},
rand::{thread_rng, Rng},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue,
Vectorizable,
},
seq_join::{seq_join, SeqJoin},
seq_join::seq_join,
sharding::NotSharded,
test_fixture::{join3v, Reconstruct, Runner, TestWorld},
};

/// Runs a single-record semi-honest `select` over shares of random inputs and
/// checks the reconstructed output against the plaintext result.
async fn test_select_semi_honest<V>()
where
    V: BooleanArray,
    for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedSemiHonestContext<'a, NotSharded>>,
    Standard: Distribution<V>,
{
    let world = TestWorld::default();
    let contexts = world.contexts();
    let mut rng = thread_rng();

    let bit = rng.gen::<Boolean>();
    let a = rng.gen::<V>();
    let b = rng.gen::<V>();
    let expected = if bit.into() { a } else { b };

    let shares = zip(
        bit.share_with(&mut rng),
        zip(a.share_with(&mut rng), b.share_with(&mut rng)),
    );

    let futures = zip(contexts.iter(), shares).map(
        |(ctx, (bit_share, (a_share, b_share)))| async move {
            let validator = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
            let sh_ctx = validator.context();

            let result = select(
                sh_ctx.set_total_records(1),
                RecordId::from(0),
                &bit_share,
                &a_share,
                &b_share,
            )
            .await?;

            validator.validate().await?;

            Ok::<_, Error>(result)
        },
    );

    let ab = join3v(futures).await.reconstruct();

    assert_eq!(ab, expected);
}

#[tokio::test]
async fn dzkp_malicious() {
async fn select_semi_honest() {
test_select_semi_honest::<BA3>().await;
test_select_semi_honest::<BA8>().await;
test_select_semi_honest::<BA16>().await;
test_select_semi_honest::<BA20>().await;
test_select_semi_honest::<BA32>().await;
test_select_semi_honest::<BA64>().await;
test_select_semi_honest::<BA256>().await;
}

/// Runs a single-record malicious `select` over shares of random inputs and
/// checks the reconstructed output against the plaintext result.
async fn test_select_malicious<V>()
where
    V: BooleanArray,
    for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
    Standard: Distribution<V>,
{
    let world = TestWorld::default();
    let contexts = world.malicious_contexts();
    let mut rng = thread_rng();

    let bit = rng.gen::<Boolean>();
    let a = rng.gen::<V>();
    let b = rng.gen::<V>();
    let expected = if bit.into() { a } else { b };

    let shares = zip(
        bit.share_with(&mut rng),
        zip(a.share_with(&mut rng), b.share_with(&mut rng)),
    );

    let futures = zip(contexts.iter(), shares).map(
        |(ctx, (bit_share, (a_share, b_share)))| async move {
            let validator = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
            let m_ctx = validator.context();

            let result = select(
                m_ctx.set_total_records(1),
                RecordId::from(0),
                &bit_share,
                &a_share,
                &b_share,
            )
            .await?;

            validator.validate().await?;

            Ok::<_, Error>(result)
        },
    );

    let ab = join3v(futures).await.reconstruct();

    assert_eq!(ab, expected);
}

/// Single-record malicious `select` across a range of `BooleanArray` widths;
/// mirrors `select_semi_honest`.
#[tokio::test]
async fn select_malicious() {
    test_select_malicious::<BA3>().await;
    test_select_malicious::<BA8>().await;
    test_select_malicious::<BA16>().await;
    test_select_malicious::<BA20>().await;
    test_select_malicious::<BA32>().await;
    test_select_malicious::<BA64>().await;
    test_select_malicious::<BA256>().await;
}

#[tokio::test]
async fn two_multiplies_malicious() {
const COUNT: usize = 32;
let mut rng = thread_rng();

Expand Down Expand Up @@ -913,9 +1036,54 @@ mod tests {
}
}

/// Similar to `test_select_malicious`, but operating on vectors of `count`
/// records, run through `validated_seq_join` with batches of
/// `max_multiplications_per_gate` records.
async fn multi_select_malicious<V>(count: usize, max_multiplications_per_gate: usize)
where
    V: BooleanArray,
    for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
    Standard: Distribution<V>,
{
    let mut rng = thread_rng();

    // Random plaintext inputs, one (bit, a, b) triple per record.
    let bit: Vec<Boolean> = repeat_with(|| rng.gen::<Boolean>()).take(count).collect();
    let a: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();
    let b: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();

    let [ab0, ab1, ab2]: [Vec<Replicated<V>>; 3] = TestWorld::default()
        .malicious(
            zip(bit.clone(), zip(a.clone(), b.clone())),
            |ctx, inputs| async move {
                let v = ctx
                    .set_total_records(count)
                    .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
                let m_ctx = v.context();

                // One select per record; the validator handles batching the
                // underlying multiplications for proof generation.
                v.validated_seq_join(stream::iter(inputs).enumerate().map(
                    |(i, (bit_share, (a_share, b_share)))| {
                        let m_ctx = m_ctx.clone();
                        async move {
                            select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share)
                                .await
                        }
                    },
                ))
                .try_collect()
                .await
            },
        )
        .await
        .map(Result::unwrap);

    let ab: Vec<V> = [ab0, ab1, ab2].reconstruct();

    // Every reconstructed output must match the plaintext select result.
    for i in 0..count {
        assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] });
    }
}

/// test for testing `validated_seq_join`
/// similar to `complex_circuit` in `validator.rs`
async fn complex_circuit_dzkp(
/// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment)
async fn chained_multiplies_dzkp(
count: usize,
max_multiplications_per_gate: usize,
) -> Result<(), Error> {
Expand Down Expand Up @@ -945,7 +1113,7 @@ mod tests {
.map(|(ctx, input_shares)| async move {
let v = ctx
.set_total_records(count - 1)
.dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get());
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

let m_results = v
Expand Down Expand Up @@ -1021,19 +1189,63 @@ mod tests {
Ok(())
}

/// Strategy for the number of records in a proptest case.
fn record_count_strategy() -> impl Strategy<Value = usize> {
    // The chained_multiplies test runs count - 1 records, so a count of 1 is
    // not a valid input size for it (multi_select would accept it, though).
    // Mix a uniform range with exact powers of two to hit boundary sizes.
    prop_oneof![2usize..=512, (1u32..=9).prop_map(|exp| 2usize.pow(exp))]
}

/// Strategy for the batch size, constrained by the record count.
fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy<Value = usize> {
    // Batch size is capped at min(record_count, 128); draw either uniformly
    // from that range or as an exact power of two within it.
    let cap = record_count.min(128);
    prop_oneof![
        1usize..=cap,
        (0u32..=cap.ilog2()).prop_map(|exp| 2usize.pow(exp))
    ]
}

prop_compose! {
fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) {
(1usize<<log_count, 1usize<<log_multiplication_amount)
fn batching()
(record_count in record_count_strategy())
(record_count in Just(record_count), max_mults in max_multiplications_per_gate_strategy(record_count))
-> (usize, usize)
{
(record_count, max_mults)
}
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){
let future = async {
let _ = complex_circuit_dzkp(count, multiplication_amount).await;
};
tokio::runtime::Runtime::new().unwrap().block_on(future);
fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) {
println!("record_count {record_count} batch {max_multiplications_per_gate}");
if record_count / max_multiplications_per_gate >= 192 {
// TODO: #1269, or even if we don't fix that, don't hardcode the limit.
println!("skipping config because batch count exceeds limit of 192");
}
// This condition is correct only for active_work = 16 and record size of 1 byte.
else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 {
// TODO: #1300, read_size | batch_size.
// Note: for active work < 2048, read size matches active work.

// Besides read_size | batch_size, there is also a constraint
// something like active_work > read_size + batch_size - 1.
println!("skipping config due to read_size vs. batch_size constraints");
} else {
tokio::runtime::Runtime::new().unwrap().block_on(async {
chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap();
/*
multi_select_malicious::<BA3>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA8>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA16>(record_count, max_multiplications_per_gate).await;
*/
multi_select_malicious::<BA20>(record_count, max_multiplications_per_gate).await;
/*
multi_select_malicious::<BA32>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA64>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA256>(record_count, max_multiplications_per_gate).await;
*/
});
}
}
}

Expand Down
Loading

0 comments on commit 59ca3e7

Please sign in to comment.